diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py new file mode 100644 index 0000000000000..75ad094fa1382 --- /dev/null +++ b/.buildkite/check-wheel-size.py @@ -0,0 +1,36 @@ +import os +import zipfile + +MAX_SIZE_MB = 200 + + +def print_top_10_largest_files(zip_file): + with zipfile.ZipFile(zip_file, 'r') as z: + file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()] + file_sizes.sort(key=lambda x: x[1], reverse=True) + for f, size in file_sizes[:10]: + print(f"{f}: {size/(1024*1024)} MBs uncompressed.") + + +def check_wheel_size(directory): + for root, _, files in os.walk(directory): + for f in files: + if f.endswith(".whl"): + wheel_path = os.path.join(root, f) + wheel_size = os.path.getsize(wheel_path) + wheel_size_mb = wheel_size / (1024 * 1024) + if wheel_size_mb > MAX_SIZE_MB: + print( + f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) " + f"compare to the allowed size ({MAX_SIZE_MB} MB).") + print_top_10_largest_files(wheel_path) + return 1 + else: + print(f"Wheel {wheel_path} is within the allowed size " + f"({wheel_size_mb} MB).") + return 0 + + +if __name__ == "__main__": + import sys + sys.exit(check_wheel_size(sys.argv[1])) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 83a56e25aca73..bde8ab6184d3c 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -1,38 +1,73 @@ -# This script build the ROCm docker image and run the API server inside the container. -# It serves a sanity check for compilation and basic model usage. +# This script runs test inside the corresponding ROCm docker container. set -ex # Print ROCm version +echo "--- ROCm info" rocminfo -# Try building the docker image -docker build -t rocm -f Dockerfile.rocm . +# cleanup older docker images +cleanup_docker() { + # Get Docker's root directory + docker_root=$(docker info -f '{{.DockerRootDir}}') + if [ -z "$docker_root" ]; then + echo "Failed to determine Docker root directory." + exit 1 + fi + echo "Docker root directory: $docker_root" + # Check disk usage of the filesystem where Docker's root directory is located + disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//') + # Define the threshold + threshold=70 + if [ "$disk_usage" -gt "$threshold" ]; then + echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." + # Remove dangling images (those that are not tagged and not used by any container) + docker image prune -f + # Remove unused volumes + docker volume prune -f + echo "Docker images and volumes cleanup completed." + else + echo "Disk usage is below $threshold%. No cleanup needed." + fi +} -# Setup cleanup -remove_docker_container() { docker rm -f rocm || true; } -trap remove_docker_container EXIT -remove_docker_container - -# Run the image -docker run --device /dev/kfd --device /dev/dri --network host --name rocm rocm python3 -m vllm.entrypoints.api_server & - -# Wait for the server to start -wait_for_server_to_start() { - timeout=300 - counter=0 - - while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do - sleep 1 - counter=$((counter + 1)) - if [ $counter -ge $timeout ]; then - echo "Timeout after $timeout seconds" - break +# Call the cleanup docker function +cleanup_docker + +echo "--- Resetting GPUs" + +echo "reset" > /opt/amdgpu/etc/gpu_state + +while true; do + sleep 3 + if grep -q clean /opt/amdgpu/etc/gpu_state; then + echo "GPUs state is \"clean\"" + break fi - done +done + +echo "--- Building container" +sha=$(git rev-parse --short HEAD) +image_name=rocm_${sha} +container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo) +docker build \ + -t ${image_name} \ + -f Dockerfile.rocm \ + --progress plain \ + . + +remove_docker_container() { + docker rm -f ${container_name} || docker image rm -f ${image_name} || true } -wait_for_server_to_start +trap remove_docker_container EXIT + +echo "--- Running container" + +docker run \ + --device /dev/kfd --device /dev/dri \ + --network host \ + --rm \ + -e HF_TOKEN \ + --name ${container_name} \ + ${image_name} \ + /bin/bash -c "${@}" -# Test a simple prompt -curl -X POST -H "Content-Type: application/json" \ - localhost:8000/generate \ - -d '{"prompt": "San Francisco is a"}' diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index f6a542afe1a3d..1efc96395933f 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.." (which wget && which curl) || (apt-get update && apt-get install -y wget curl) # run python-based benchmarks and upload the result to buildkite -python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt +python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt bench_latency_exit_code=$? -python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt +python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt bench_throughput_exit_code=$? # run server-based benchmarks and upload the result to buildkite @@ -53,6 +53,11 @@ echo '```' >> benchmark_results.md tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines echo '```' >> benchmark_results.md +# if the agent binary is not found, skip uploading the results, exit 0 +if [ ! -f /workspace/buildkite-agent ]; then + exit 0 +fi + # upload the results to buildkite /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md @@ -69,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then exit $bench_serving_exit_code fi -/workspace/buildkite-agent artifact upload openai-*.json +rm ShareGPT_V3_unfiltered_cleaned_split.json +/workspace/buildkite-agent artifact upload "*.json" diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f187d1f181724..414045fe163e5 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -11,4 +11,4 @@ trap remove_docker_container EXIT remove_docker_container # Run the image and launch offline inference -docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py +docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh new file mode 100644 index 0000000000000..252c0f7fecd12 --- /dev/null +++ b/.buildkite/run-neuron-test.sh @@ -0,0 +1,51 @@ +# This script build the Neuron docker image and run the API server inside the container. +# It serves a sanity check for compilation and basic model usage. +set -e + +# Try building the docker image +aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com + +# prune old image and containers to save disk space, and only once a day +# by using a timestamp file in tmp. +if [ -f /tmp/neuron-docker-build-timestamp ]; then + last_build=$(cat /tmp/neuron-docker-build-timestamp) + current_time=$(date +%s) + if [ $((current_time - last_build)) -gt 86400 ]; then + docker system prune -f + echo $current_time > /tmp/neuron-docker-build-timestamp + fi +else + echo $(date +%s) > /tmp/neuron-docker-build-timestamp +fi + +docker build -t neuron -f Dockerfile.neuron . + +# Setup cleanup +remove_docker_container() { docker rm -f neuron || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image +docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \ + --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 & + +# Wait for the server to start +wait_for_server_to_start() { + timeout=300 + counter=0 + + while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do + sleep 1 + counter=$((counter + 1)) + if [ $counter -ge $timeout ]; then + echo "Timeout after $timeout seconds" + break + fi + done +} +wait_for_server_to_start + +# Test a simple prompt +curl -X POST -H "Content-Type: application/json" \ + localhost:8000/generate \ + -d '{"prompt": "San Francisco is a"}' diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 27e44463a30a6..21cbd9ba13780 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -5,89 +5,155 @@ steps: - label: Regression Test + mirror_hardwares: [amd] command: pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional - label: AsyncEngine Test + #mirror_hardwares: [amd] command: pytest -v -s async_engine - label: Basic Correctness Test - command: pytest -v -s basic_correctness + mirror_hardwares: [amd] + commands: + - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Core Test + mirror_hardwares: [amd] command: pytest -v -s core - label: Distributed Comm Ops Test - command: pytest -v -s test_comm_ops.py - working_dir: "/vllm-workspace/tests/distributed" - num_gpus: 2 # only support 1 or 2 for now. + #mirror_hardwares: [amd] + command: pytest -v -s distributed/test_comm_ops.py + working_dir: "/vllm-workspace/tests" + num_gpus: 2 - label: Distributed Tests - working_dir: "/vllm-workspace/tests/distributed" - num_gpus: 2 # only support 1 or 2 for now. + mirror_hardwares: [amd] + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + commands: + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py + - pytest -v -s spec_decode/e2e/test_integration_dist.py + +- label: Distributed Tests (Multiple Groups) + #mirror_hardwares: [amd] + working_dir: "/vllm-workspace/tests" + num_gpus: 4 commands: - - pytest -v -s test_pynccl.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py + - pytest -v -s distributed/test_pynccl.py - label: Engine Test - command: pytest -v -s engine tokenization test_sequence.py test_config.py + mirror_hardwares: [amd] + command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test + mirror_hardwares: [amd] + commands: - # these tests have to be separated, because each one will allocate all posible GPU memory - - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - - pytest -v -s entrypoints/test_server_oot_registration.py + - pytest -v -s test_inputs.py + - pytest -v -s entrypoints -m llm + - pytest -v -s entrypoints -m openai - label: Examples Test working_dir: "/vllm-workspace/examples" + mirror_hardwares: [amd] commands: # install aws cli for llava_example.py - - pip install awscli + # install tensorizer for tensorize_vllm_model.py + - pip install awscli tensorizer - python3 offline_inference.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 llava_example.py + - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - label: Kernels Test %N + #mirror_hardwares: [amd] command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Models Test + #mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py + - pytest -v -s models --ignore=models/test_llava.py - label: Llava Test + mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - pytest -v -s models/test_llava.py - label: Prefix Caching Test + mirror_hardwares: [amd] commands: - pytest -v -s prefix_caching - label: Samplers Test + #mirror_hardwares: [amd] command: pytest -v -s samplers - label: LogitsProcessor Test + mirror_hardwares: [amd] command: pytest -v -s test_logits_processor.py +- label: Utils Test + command: pytest -v -s test_utils.py + - label: Worker Test + mirror_hardwares: [amd] command: pytest -v -s worker - label: Speculative decoding tests + #mirror_hardwares: [amd] command: pytest -v -s spec_decode - label: LoRA Test %N - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + #mirror_hardwares: [amd] + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 +- label: LoRA Long Context (Distributed) + #mirror_hardwares: [amd] + num_gpus: 4 + # This test runs llama 13B, so it is required to run on 4 GPUs. + commands: + # Temporarily run this way because we cannot clean up GPU mem usage + # for multi GPU tests. + # TODO(sang): Fix it. + - pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced + - pytest -v -s lora/test_long_context.py::test_batched_rope_kernel + - pytest -v -s lora/test_long_context.py::test_self_consistency + - pytest -v -s lora/test_long_context.py::test_quality + - pytest -v -s lora/test_long_context.py::test_max_len + +- label: Tensorizer Test + #mirror_hardwares: [amd] + command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader + - label: Metrics Test + mirror_hardwares: [amd] command: pytest -v -s metrics +- label: Quantization Test + #mirror_hardwares: [amd] + command: pytest -v -s quantization + - label: Benchmarks working_dir: "/vllm-workspace/.buildkite" + mirror_hardwares: [amd] commands: - pip install aiohttp - bash run-benchmarks.sh diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 3ed23c62c005d..265833e2ccf6e 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -3,16 +3,8 @@ {% set default_working_dir = "/vllm-workspace/tests" %} steps: - - label: "AMD Test" - agents: - queue: amd - command: bash .buildkite/run-amd-test.sh - - - label: "CPU Test" - command: bash .buildkite/run-cpu-test.sh - - label: ":docker: build image" - commands: + commands: - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." - "docker push {{ docker_image }}" env: @@ -21,8 +13,35 @@ steps: automatic: - exit_status: -1 # Agent was lost limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 - wait + - group: "AMD Tests" + depends_on: ~ + steps: + {% for step in steps %} + {% if step.mirror_hardwares and "amd" in step.mirror_hardwares %} + - label: "AMD: {{ step.label }}" + agents: + queue: amd + command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}" + env: + DOCKER_BUILDKIT: "1" + {% endif %} + {% endfor %} + + - label: "Neuron Test" + depends_on: ~ + agents: + queue: neuron + command: bash .buildkite/run-neuron-test.sh + soft_fail: true + + - label: "Intel Test" + depends_on: ~ + command: bash .buildkite/run-cpu-test.sh + {% for step in steps %} - label: "{{ step.label }}" agents: @@ -35,9 +54,14 @@ steps: automatic: - exit_status: -1 # Agent was lost limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 plugins: - kubernetes: podSpec: + {% if step.num_gpus %} + priorityClassName: gpu-priority-cls-{{ step.num_gpus }} + {% endif %} volumes: - name: dshm emptyDir: diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000..7f9e6d720fae5 --- /dev/null +++ b/.clang-format @@ -0,0 +1,26 @@ +BasedOnStyle: Google +UseTab: Never +IndentWidth: 2 +ColumnLimit: 80 + +# Force pointers to the type for C++. +DerivePointerAlignment: false +PointerAlignment: Left + +# Reordering #include statements can (and currently will) introduce errors +SortIncludes: false + +# Style choices +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +IndentPPDirectives: BeforeHash + +IncludeCategories: + - Regex: '^<' + Priority: 4 + - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/' + Priority: 3 + - Regex: '^"(qoda|\.\.)/' + Priority: 2 + - Regex: '.*' + Priority: 1 diff --git a/.github/ISSUE_TEMPLATE/200-installation.yml b/.github/ISSUE_TEMPLATE/200-installation.yml index 4c6c96187cc6c..df41ade8c3c01 100644 --- a/.github/ISSUE_TEMPLATE/200-installation.yml +++ b/.github/ISSUE_TEMPLATE/200-installation.yml @@ -18,6 +18,7 @@ body: # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` + It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. value: | ```text The output of `python collect_env.py` diff --git a/.github/ISSUE_TEMPLATE/300-usage.yml b/.github/ISSUE_TEMPLATE/300-usage.yml index 88227b4b2e7b9..54763af1058f6 100644 --- a/.github/ISSUE_TEMPLATE/300-usage.yml +++ b/.github/ISSUE_TEMPLATE/300-usage.yml @@ -18,6 +18,7 @@ body: # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` + It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. value: | ```text The output of `python collect_env.py` diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml index f1124dfa78bbc..ce980c3f4a01d 100644 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug report.yml @@ -18,6 +18,7 @@ body: # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` + It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. value: | ```text The output of `python collect_env.py` @@ -57,6 +58,10 @@ body: If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + + Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues. + + If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. placeholder: | A clear and concise description of what the bug is. diff --git a/.github/ISSUE_TEMPLATE/700-performance discussion.yml b/.github/ISSUE_TEMPLATE/700-performance discussion.yml index 9e8e7b4aa3530..4f8843420a94e 100644 --- a/.github/ISSUE_TEMPLATE/700-performance discussion.yml +++ b/.github/ISSUE_TEMPLATE/700-performance discussion.yml @@ -39,6 +39,7 @@ body: # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` + It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. value: | ```text The output of `python collect_env.py` diff --git a/.github/ISSUE_TEMPLATE/750-RFC.yml b/.github/ISSUE_TEMPLATE/750-RFC.yml new file mode 100644 index 0000000000000..5382b124dcd79 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/750-RFC.yml @@ -0,0 +1,49 @@ +name: 💬 Request for comments (RFC). +description: Ask for feedback on major architectural changes or design choices. +title: "[RFC]: " +labels: ["RFC"] + +body: +- type: markdown + attributes: + value: > + #### Please take a look at previous [RFCs](https://github.com/vllm-project/vllm/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference. +- type: textarea + attributes: + label: Motivation. + description: > + The motivation of the RFC. + validations: + required: true +- type: textarea + attributes: + label: Proposed Change. + description: > + The proposed change of the RFC. + validations: + required: true +- type: textarea + attributes: + label: Feedback Period. + description: > + The feedback period of the RFC. Usually at least one week. + validations: + required: false +- type: textarea + attributes: + label: CC List. + description: > + The list of people you want to CC. + validations: + required: false +- type: textarea + attributes: + label: Any Other Things. + description: > + Any other things you would like to mention. + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml new file mode 100644 index 0000000000000..e9b6e28fa6bcb --- /dev/null +++ b/.github/workflows/clang-format.yml @@ -0,0 +1,42 @@ +name: clang-format + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + clang-format: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install clang-format==18.1.5 + - name: Running clang-format + run: | + EXCLUDES=( + 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' + 'csrc/punica/bgmv/bgmv_config.h' + 'csrc/punica/bgmv/bgmv_impl.cuh' + 'csrc/punica/bgmv/vec_dtypes.cuh' + 'csrc/punica/punica_ops.cu' + 'csrc/punica/type_convert.h' + ) + find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ + | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ + | xargs clang-format --dry-run --Werror \ No newline at end of file diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml new file mode 100644 index 0000000000000..a20753d8a7702 --- /dev/null +++ b/.github/workflows/mypy.yaml @@ -0,0 +1,50 @@ +name: mypy + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + ruff: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install mypy==1.9.0 + pip install types-setuptools + pip install types-PyYAML + pip install types-requests + pip install types-setuptools + - name: Mypy + run: | + mypy vllm/attention --config-file pyproject.toml + mypy vllm/core --config-file pyproject.toml + mypy vllm/distributed --config-file pyproject.toml + mypy vllm/entrypoints --config-file pyproject.toml + mypy vllm/executor --config-file pyproject.toml + mypy vllm/usage --config-file pyproject.toml + mypy vllm/*.py --config-file pyproject.toml + mypy vllm/transformers_utils --config-file pyproject.toml + mypy vllm/engine --config-file pyproject.toml + mypy vllm/worker --config-file pyproject.toml + mypy vllm/spec_decode --config-file pyproject.toml + mypy vllm/model_executor --config-file pyproject.toml + mypy vllm/lora --config-file pyproject.toml + mypy vllm/logging --config-file pyproject.toml + mypy vllm/model_executor --config-file pyproject.toml + diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index fc97e33c19af2..9c35ede5f6781 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -49,13 +49,19 @@ jobs: matrix: os: ['ubuntu-20.04'] python-version: ['3.8', '3.9', '3.10', '3.11'] - pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt. + pytorch-version: ['2.3.0'] # Must be the most recent version that meets requirements-cuda.txt. cuda-version: ['11.8', '12.1'] steps: - name: Checkout uses: actions/checkout@v3 + - name: Setup ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + create-symlink: true + key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} + - name: Set up Linux Env if: ${{ runner.os == 'Linux' }} run: | @@ -76,6 +82,8 @@ jobs: - name: Build wheel shell: bash + env: + CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size run: | bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} wheel_name=$(ls dist/*whl | xargs -n 1 basename) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index e8060e369a889..e71033f828006 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/scripts/create_release.js b/.github/workflows/scripts/create_release.js index 0f25624b4c21c..475742118afeb 100644 --- a/.github/workflows/scripts/create_release.js +++ b/.github/workflows/scripts/create_release.js @@ -8,7 +8,7 @@ module.exports = async (github, context, core) => { generate_release_notes: true, name: process.env.RELEASE_TAG, owner: context.repo.owner, - prerelease: false, + prerelease: true, repo: context.repo.repo, tag_name: process.env.RELEASE_TAG, }); diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml index b163c960db555..04f307bcf8b0e 100644 --- a/.github/workflows/yapf.yml +++ b/.github/workflows/yapf.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.gitignore b/.gitignore index b1513ef0ddb0c..e077366d1e4a1 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 +70,8 @@ instance/ # Sphinx documentation docs/_build/ +docs/source/getting_started/examples/*.rst +!**/*.template.rst # PyBuilder .pybuilder/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b6ea4b570a99..ad562d9c996f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # requirements.txt files and should be kept consistent. The ROCm torch # versions are derived from Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1") +set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0") set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1") set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1") @@ -174,17 +174,52 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/fp8/convert_kernel.cu" - "csrc/quantization/fp8/gemm_kernel.cu" + "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" + "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" "csrc/pybind.cpp") +if(VLLM_GPU_LANG STREQUAL "HIP") + list(APPEND VLLM_EXT_SRC + "csrc/quantization/fp8/amd/gemm_kernel.cu") +endif() + if(VLLM_GPU_LANG STREQUAL "CUDA") + include(FetchContent) + SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) + FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # CUTLASS 3.5.0 + GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc + ) + FetchContent_MakeAvailable(cutlass) + list(APPEND VLLM_EXT_SRC + "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/marlin/marlin_cuda_kernel.cu" - "csrc/custom_all_reduce.cu") + "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" + "csrc/quantization/gptq_marlin/gptq_marlin.cu" + "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" + "csrc/custom_all_reduce.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu") + + # + # The CUTLASS kernels for Hopper require sm90a to be enabled. + # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. + # That adds an extra 17MB to compiled binary, so instead we selectively enable it. + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() + endif() define_gpu_extension_target( @@ -194,6 +229,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} WITH_SOABI) @@ -240,24 +276,13 @@ define_gpu_extension_target( set(VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu" - "csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu" - "csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu" "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu" - "csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu" - "csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu" "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu" "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu" - "csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu" - "csrc/punica/punica_ops.cc") + "csrc/punica/punica_ops.cu" + "csrc/punica/punica_pybind.cpp") # # Copy GPU compilation flags+update for punica @@ -281,6 +306,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA") endif() endforeach() message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") +elseif(${VLLM_GPU_LANG} STREQUAL "HIP") + set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) + message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") endif() if (VLLM_PUNICA_GPU_ARCHES) @@ -317,10 +345,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") add_dependencies(default _C) add_dependencies(default _custom_C) message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) -endif() + add_dependencies(default _moe_C) -if(VLLM_GPU_LANG STREQUAL "CUDA") # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and # there are supported target arches. @@ -330,3 +356,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") add_dependencies(default _punica_C) endif() endif() + +if(VLLM_GPU_LANG STREQUAL "CUDA") + message(STATUS "Enabling moe extension.") + add_dependencies(default _moe_C) +endif() diff --git a/Dockerfile b/Dockerfile index d1d29177b0f44..eb96bf3c1db2b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,13 @@ # The vLLM Dockerfile is used to construct vLLM image that can be directly used # to run the OpenAI compatible server. +# Please update any changes made here to +# docs/source/dev/dockerfile/dockerfile.rst and +# docs/source/assets/dev/dockerfile-stages-dependency.png + #################### BASE BUILD IMAGE #################### # prepare basic build environment -FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev RUN apt-get update -y \ && apt-get install -y python3-pip git @@ -12,7 +16,7 @@ RUN apt-get update -y \ # https://github.com/pytorch/pytorch/issues/107960 -- hopefully # this won't be needed for future versions of this docker image # or future versions of triton. -RUN ldconfig /usr/local/cuda-12.1/compat/ +RUN ldconfig /usr/local/cuda-12.4/compat/ WORKDIR /workspace @@ -71,34 +75,15 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/pip \ python3 setup.py bdist_wheel --dist-dir=dist -# the `vllm_nccl` package must be installed from source distribution -# pip is too smart to store a wheel in the cache, and other CI jobs -# will directly use the wheel from the cache, which is not what we want. -# we need to remove it manually -RUN --mount=type=cache,target=/root/.cache/pip \ - pip cache remove vllm_nccl* -#################### EXTENSION Build IMAGE #################### - -#################### FLASH_ATTENTION Build IMAGE #################### -FROM dev as flash-attn-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} -# flash attention version -ARG flash_attn_version=v2.5.6 -ENV FLASH_ATTN_VERSION=${flash_attn_version} +# check the size of the wheel, we cannot upload wheels larger than 100MB +COPY .buildkite/check-wheel-size.py check-wheel-size.py +RUN python3 check-wheel-size.py dist -WORKDIR /usr/src/flash-attention-v2 - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ - --no-build-isolation --no-deps --no-cache-dir - -#################### FLASH_ATTENTION Build IMAGE #################### +#################### EXTENSION Build IMAGE #################### #################### vLLM installation IMAGE #################### # image with vLLM installed -FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base +FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base WORKDIR /vllm-workspace RUN apt-get update -y \ @@ -108,16 +93,12 @@ RUN apt-get update -y \ # https://github.com/pytorch/pytorch/issues/107960 -- hopefully # this won't be needed for future versions of this docker image # or future versions of triton. -RUN ldconfig /usr/local/cuda-12.1/compat/ +RUN ldconfig /usr/local/cuda-12.4/compat/ # install vllm wheel first, so that torch etc will be installed RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ pip install dist/*.whl --verbose - -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - --mount=type=cache,target=/root/.cache/pip \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir #################### vLLM installation IMAGE #################### diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 4251fddd6cc3b..aec79824213f3 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -17,4 +17,6 @@ RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.py RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +WORKDIR /workspace/ + CMD ["/bin/bash"] diff --git a/Dockerfile.neuron b/Dockerfile.neuron new file mode 100644 index 0000000000000..fe42b4ef393f1 --- /dev/null +++ b/Dockerfile.neuron @@ -0,0 +1,36 @@ +# default base image +ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04" + +FROM $BASE_IMAGE + +RUN echo "Base image is $BASE_IMAGE" + +# Install some basic utilities +RUN apt-get update && apt-get install python3 python3-pip -y + +### Mount Point ### +# When launching the container, mount the code directory to /app +ARG APP_MOUNT=/app +VOLUME [ ${APP_MOUNT} ] +WORKDIR ${APP_MOUNT} + +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas +RUN python3 -m pip install sentencepiece transformers==4.36.2 -U +RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U + +COPY ./vllm /app/vllm/vllm +COPY ./setup.py /app/vllm/setup.py +COPY ./requirements-common.txt /app/vllm/requirements-common.txt +COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt + +RUN cd /app/vllm \ + && python3 -m pip install -U -r requirements-neuron.txt + +ENV VLLM_BUILD_WITH_NEURON 1 +RUN cd /app/vllm \ + && pip install -e . \ + && cd .. + +CMD ["/bin/bash"] diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 04102560452a1..6d0dd31d346f1 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -153,7 +153,6 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl / COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/gradlib/dist/*.whl / COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/rocm_patch /rocm_patch COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt / -COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/patch_xformers.rocm.sh / # ----------------------- # Final vLLM image @@ -199,14 +198,13 @@ RUN --mount=type=bind,from=export_triton,src=/,target=/install \ fi RUN python3 -m pip install --upgrade numba -RUN python3 -m pip install xformers==0.0.23 --no-deps # Install vLLM (and gradlib) +# Make sure punica kernels are built (for LoRA) +ENV VLLM_INSTALL_PUNICA_KERNELS=1 RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ cd /install \ && pip install -U -r requirements-rocm.txt \ - && if [ "$BUILD_FA" = "1" ]; then \ - bash patch_xformers.rocm.sh; fi \ && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ *"rocm-6.0"*) \ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \ @@ -215,11 +213,9 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ *) ;; esac \ && pip install *.whl -# Update Ray to latest version + set environment variable to ensure it works on TP > 1 -RUN python3 -m pip install --no-cache-dir 'ray[all]>=2.10.0' ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 -# HIPgraph performance environment variable. +# Performance environment variable. ENV HIP_FORCE_DEV_KERNARG=1 CMD ["/bin/bash"] diff --git a/MANIFEST.in b/MANIFEST.in index d385f194c6c0f..82be639ef4d73 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,9 @@ include LICENSE include requirements-common.txt include requirements-cuda.txt +include requirements-rocm.txt +include requirements-neuron.txt +include requirements-cpu.txt include CMakeLists.txt recursive-include cmake * diff --git a/README.md b/README.md index d53227b82d87a..d63819c3815c0 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,17 @@ Easy, fast, and cheap LLM serving for everyone

+--- + +**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)** + +We are thrilled to announce our fourth vLLM Meetup! +The vLLM team will share recent updates and roadmap. +We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM. +Please register [here](https://lu.ma/agivllm) and join us! + +--- + *Latest News* 🔥 - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). @@ -51,40 +62,14 @@ vLLM is flexible and easy to use with: - (Experimental) Prefix caching support - (Experimental) Multi-lora support -vLLM seamlessly supports many Hugging Face models, including the following architectures: - -- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) -- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) -- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) -- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) -- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.) -- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.) -- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) -- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) -- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.) -- GPT-2 (`gpt2`, `gpt2-xl`, etc.) -- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) -- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) -- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) -- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) -- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) -- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) -- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) -- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) -- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) -- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) -- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) -- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.) -- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) -- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) -- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) -- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) -- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.) -- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) -- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) -- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.) -- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.) -- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) +vLLM seamlessly supports most popular open-source models on HuggingFace, including: +- Transformer-like LLMs (e.g., Llama) +- Mixture-of-Expert LLMs (e.g., Mixtral) +- Multi-modal LLMs (e.g., LLaVA) + +Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). + +## Getting Started Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): @@ -92,9 +77,7 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get pip install vllm ``` -## Getting Started - -Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started. +Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) @@ -104,6 +87,32 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started We welcome and value any contributions and collaborations. Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. +## Sponsors + +vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! + + + + +- a16z +- AMD +- Anyscale +- AWS +- Crusoe Cloud +- Databricks +- DeepInfra +- Dropbox +- Lambda Lab +- NVIDIA +- Replicate +- Roblox +- RunPod +- Trainy +- UC Berkeley +- UC San Diego + +We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. + ## Citation If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): diff --git a/ROCm_performance.md b/ROCm_performance.md index 0f12ed1adc9af..180c848a21950 100644 --- a/ROCm_performance.md +++ b/ROCm_performance.md @@ -18,30 +18,3 @@ Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order On ROCm, to have better performance, a custom paged attention is available by switching on the env variable: `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`. Currently, this env variable is enabled by default. To fallback to PagedAttention v2 kernel assign the env variable to 0. The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head-size=128, and max context length <= 16k, with GQA ratio (num_heads//num_kv_heads) between 1 to 16. On all the other cases, we fallback to PagedAttention v2 kernel. - -## Fp8 Quantization - -To use fp8 quantization, first step is to quantize your model to fp8 format. Generating a safetensor file that contains the quantized weights and the corresponding scaling factors of your model. The safetensor file should be added under your model folder along with a file called `serenity_config.json`, which contains a json object with a key: `"quantized_weights": "quantized/osf/rank0.safetensors"`, the value should be the relative path of your safetensor file containing the quantized weights. - -Then we can run a model with fp8 quantization using vllm, just add a parameter `quantization="fp8"` when creating the `vllm.LLM` object. - -## Gemm Tunning for Fp8 - -To get better performance of fp8 quantization, we will need to tune the gemm with the information of all the shapes used in the execution of the model. - -To obtain all the shapes of gemms during the execution of the model, set the env value TUNE_FP8=1 and the run the model as usual. We will get the a file called `/fp8_shapes.csv`. - -Next, run gradlib to obtain the best solutions of these shapes: - -``` -cd gradlib_fp8 -python3 -m pip uninstall gradlib -python3 setup.py install -python3 gemm_tuner.py --input_file /fp8_shapes.csv --tuned_file /tuned_fp8_16.csv -cd ../gradlib -python3 -m pip uninstall gradlib -python3 setup.py install -cd .. -``` - -Now, when running inference with fp8, we are using the tunned gemm for best performance. \ No newline at end of file diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index ad428bd1c3644..58dcc6167efa6 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -27,8 +27,8 @@ class RequestFuncInput: class RequestFuncOutput: generated_text: str = "" success: bool = False - latency: float = 0 - ttft: float = 0 # Time to first token + latency: float = 0.0 + ttft: float = 0.0 # Time to first token itl: List[float] = field( default_factory=list) # List of inter-token latencies prompt_len: int = 0 @@ -58,23 +58,24 @@ async def async_request_tgi( output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data:") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data:") data = json.loads(chunk) timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -88,6 +89,9 @@ async def async_request_tgi( output.latency = most_recent_timestamp - st output.success = True output.generated_text = data["generated_text"] + else: + output.error = response.reason or "" + output.success = False except Exception: output.success = False exc_info = sys.exc_info() @@ -119,23 +123,25 @@ async def async_request_trt_llm( output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data:") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data:") data = json.loads(chunk) + output.generated_text += data["text_output"] timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -147,11 +153,10 @@ async def async_request_trt_llm( most_recent_timestamp = timestamp output.latency = most_recent_timestamp - st - output.generated_text = json.loads(data)["text_output"] output.success = True else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False @@ -195,7 +200,7 @@ async def async_request_deepspeed_mii( output.generated_text = parsed_resp["text"][0] output.success = True else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False @@ -234,19 +239,20 @@ async def async_request_openai_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data: ") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data: ") if chunk == "[DONE]": latency = time.perf_counter() - st else: @@ -255,7 +261,7 @@ async def async_request_openai_completions( if data["choices"][0]["text"]: timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -273,6 +279,9 @@ async def async_request_openai_completions( output.generated_text = generated_text output.success = True output.latency = latency + else: + output.error = response.reason or "" + output.success = False except Exception: output.success = False exc_info = sys.exc_info() @@ -315,19 +324,20 @@ async def async_request_openai_chat_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data: ") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data: ") if chunk == "[DONE]": latency = time.perf_counter() - st else: @@ -337,7 +347,7 @@ async def async_request_openai_chat_completions( delta = data["choices"][0]["delta"] if delta.get("content", None): # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -354,7 +364,7 @@ async def async_request_openai_chat_completions( output.success = True output.latency = latency else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a160e4cabf3f8..2aca1b23f9b6f 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,14 +1,17 @@ """Benchmark the latency of processing a single batch of requests.""" import argparse +import json import time from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy as np import torch from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.inputs import PromptStrictInputs +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS def main(args: argparse.Namespace): @@ -17,6 +20,8 @@ def main(args: argparse.Namespace): # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. llm = LLM(model=args.model, + speculative_model=args.speculative_model, + num_speculative_tokens=args.num_speculative_tokens, tokenizer=args.tokenizer, quantization=args.quantization, tensor_parallel_size=args.tensor_parallel_size, @@ -28,9 +33,11 @@ def main(args: argparse.Namespace): device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, worker_use_ray=args.worker_use_ray, + use_v2_block_manager=args.use_v2_block_manager, enable_chunked_prefill=args.enable_chunked_prefill, download_dir=args.download_dir, - block_size=args.block_size) + block_size=args.block_size, + gpu_memory_utilization=args.gpu_memory_utilization) sampling_params = SamplingParams( n=args.n, @@ -44,7 +51,9 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() + dummy_inputs: List[PromptStrictInputs] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: @@ -55,13 +64,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() @@ -93,16 +102,28 @@ def run_to_completion(profile_dir: Optional[str] = None): for percentage, percentile in zip(percentages, percentiles): print(f'{percentage}% percentile latency: {percentile} seconds') + # Output JSON results if specified + if args.output_json: + results = { + "avg_latency": np.mean(latencies), + "latencies": latencies.tolist(), + "percentiles": dict(zip(percentages, percentiles.tolist())), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + if __name__ == '__main__': parser = argparse.ArgumentParser( description='Benchmark the latency of processing a single batch of ' 'requests till completion.') parser.add_argument('--model', type=str, default='facebook/opt-125m') + parser.add_argument('--speculative-model', type=str, default=None) + parser.add_argument('--num-speculative-tokens', type=int, default=None) parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'gptq', 'squeezellm', None], + choices=[*QUANTIZATION_METHODS, None], default=None) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) @@ -137,15 +158,13 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='enforce eager mode and disable CUDA graph') parser.add_argument( - "--kv-cache-dtype", + '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], - default='auto', - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, @@ -178,10 +197,10 @@ def run_to_completion(profile_dir: Optional[str] = None): help='block size of key/value cache') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, + action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') + parser.add_argument('--use-v2-block-manager', action='store_true') parser.add_argument( "--ray-workers-use-nsight", action='store_true', @@ -197,5 +216,16 @@ def run_to_completion(profile_dir: Optional[str] = None): default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the latency results in JSON format.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 1f3274a28cad5..089966986984f 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -16,20 +16,22 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): def main(args): - llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat", + llm = LLM(model=args.model, tokenizer_mode='auto', trust_remote_code=True, enforce_eager=True, + use_v2_block_manager=args.use_v2_block_manager, + tensor_parallel_size=args.tensor_parallel_size, enable_prefix_caching=args.enable_prefix_caching) num_prompts = 100 prompts = [PROMPT] * num_prompts - sampling_params = SamplingParams(temperature=0, max_tokens=100) + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) print("------warm up------") test_prefix( llm=llm, - prompts=prompts[:1], + prompts=prompts, sampling_params=sampling_params, ) @@ -45,8 +47,16 @@ def main(args): parser = argparse.ArgumentParser( description='Benchmark the performance with or without automatic ' 'prefix caching.') + parser.add_argument('--model', + type=str, + default='baichuan-inc/Baichuan2-13B-Chat') + parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--output-len', type=int, default=10) parser.add_argument('--enable-prefix-caching', action='store_true', help='enable prefix caching') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='Use BlockSpaceMangerV2') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 6054df439fa57..f3d71de775f82 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -17,6 +17,10 @@ --dataset-path \ --request-rate \ # By default is inf --num-prompts # By default is 1000 + + when using tgi backend, add + --endpoint /generate_stream + to the end of the command above. """ import argparse import asyncio @@ -27,7 +31,7 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import AsyncGenerator, List, Tuple +from typing import AsyncGenerator, List, Optional, Tuple import numpy as np from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, @@ -58,7 +62,11 @@ def sample_sharegpt_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) @@ -68,38 +76,32 @@ def sample_sharegpt_requests( dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] - # some of these will be filtered out, so sample more than we need - sampled_indices = random.sample(range(len(dataset)), - int(num_requests * 1.2)) - dataset = [dataset[i] for i in sampled_indices] - - # Tokenize the prompts and completions. - prompts = [prompt for prompt, _ in dataset] - prompt_token_ids = tokenizer(prompts).input_ids - completions = [completion for _, completion in dataset] - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): - output_len = len(completion_token_ids[i]) - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + # Shuffle the dataset. + random.shuffle(dataset) - # Filter out too long sequences. + # Filter out sequences that are too long or too short filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len if prompt_len < 4 or output_len < 4: # Prune too short sequences. - # This is because TGI causes errors when the input or output length - # is too short. continue if prompt_len > 1024 or prompt_len + output_len > 2048: # Prune too long sequences. continue filtered_dataset.append((prompt, prompt_len, output_len)) - # Sample the requests. - sampled_requests = random.sample(filtered_dataset, num_requests) - return sampled_requests + return filtered_dataset def sample_sonnet_requests( @@ -213,6 +215,11 @@ def calculate_metrics( else: actual_output_lens.append(0) + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -224,9 +231,9 @@ def calculate_metrics( 1000, # ttfts is empty if streaming is not supported by backend median_ttft_ms=np.median(ttfts or 0) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, - mean_tpot_ms=np.mean(tpots) * 1000, - median_tpot_ms=np.median(tpots) * 1000, - p99_tpot_ms=np.percentile(tpots, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, ) return metrics, actual_output_lens @@ -248,6 +255,24 @@ async def benchmark( else: raise ValueError(f"Unknown backend: {backend}") + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + best_of=best_of, + use_beam_search=use_beam_search, + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") print(f"Traffic request rate: {request_rate}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) @@ -361,6 +386,7 @@ def main(args: argparse.Namespace): dataset_path=args.dataset, num_requests=args.num_prompts, tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, ) elif args.dataset_name == "sharegpt": @@ -368,6 +394,7 @@ def main(args: argparse.Namespace): dataset_path=args.dataset_path, num_requests=args.num_prompts, tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, ) elif args.dataset_name == "sonnet": @@ -524,6 +551,12 @@ def main(args: argparse.Namespace): default=1000, help="Number of prompts to process.", ) + parser.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.") parser.add_argument( "--sonnet-input-len", type=int, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 9248010e38063..fdfbf23c721b3 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -10,6 +10,8 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + def sample_requests( dataset_path: str, @@ -74,48 +76,51 @@ def run_vllm( quantization_param_path: Optional[str], device: str, enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, gpu_memory_utilization: float = 0.9, worker_use_ray: bool = False, download_dir: Optional[str] = None, ) -> float: from vllm import LLM, SamplingParams - llm = LLM(model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - worker_use_ray=worker_use_ray, - download_dir=download_dir) + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + worker_use_ray=worker_use_ray, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) # Add the requests to the engine. + prompts = [] + sampling_params = [] for prompt, _, output_len in requests: - sampling_params = SamplingParams( - n=n, - temperature=0.0 if use_beam_search else 1.0, - top_p=1.0, - use_beam_search=use_beam_search, - ignore_eos=True, - max_tokens=output_len, - ) - # FIXME(woosuk): Do not use internal method. - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - sampling_params=sampling_params, - ) + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) start = time.perf_counter() - # FIXME(woosuk): Do not use internal method. - llm._run_engine(use_tqdm=True) + llm.generate(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() return end - start @@ -221,7 +226,8 @@ def main(args: argparse.Namespace): args.trust_remote_code, args.dtype, args.max_model_len, args.enforce_eager, args.kv_cache_dtype, args.quantization_param_path, args.device, - args.enable_prefix_caching, args.gpu_memory_utilization, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.gpu_memory_utilization, args.worker_use_ray, args.download_dir) elif args.backend == "hf": assert args.tensor_parallel_size == 1 @@ -238,6 +244,18 @@ def main(args: argparse.Namespace): print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} tokens/s") + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Benchmark the throughput.") @@ -262,7 +280,7 @@ def main(args: argparse.Namespace): parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'gptq', 'squeezellm', None], + choices=[*QUANTIZATION_METHODS, None], default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", @@ -307,15 +325,13 @@ def main(args: argparse.Namespace): action="store_true", help="enforce eager execution") parser.add_argument( - "--kv-cache-dtype", + '--kv-cache-dtype', type=str, - choices=["auto", "fp8"], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], default="auto", - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, @@ -341,11 +357,24 @@ def main(args: argparse.Namespace): help='use Ray for distributed serving, will be ' 'automatically set when using more than 1 GPU ' 'unless on ROCm where the default is torchrun') + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') parser.add_argument('--download-dir', type=str, default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py new file mode 100644 index 0000000000000..59392947b15c8 --- /dev/null +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -0,0 +1,302 @@ +import argparse +import os +import sys +from typing import Optional + +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.aqlm import ( + dequantize_weight, generic_dequantize_gemm, get_int_dtype, + optimized_dequantize_gemm) + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def torch_mult( + input: torch.Tensor, # [..., in_features] + weights: torch.Tensor, + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] +) -> torch.Tensor: + output = F.linear(input, weights) + return output + + +def dequant_out_scale( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + if bias is None: + output = F.linear(input, weights, bias) + orig_shape = output.shape + flattened_output = output.view(-1, output.size(-1)) + f_scales = scales.view(-1, scales.shape[0]) + b_scales = f_scales.expand(flattened_output.shape[0], -1) + flattened_output *= b_scales + return flattened_output.view(orig_shape) + else: + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +def dequant_weight_scale( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +def dequant_no_scale( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + return F.linear(input, weights, bias) + + +# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against +# the generic pytorch version. +# Just visual comparison. +def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None: + + n = parts.sum().item() + + device = torch.device('cuda:0') + + code_range = (1 << bits) // 2 + ingroups = 8 + + codes = torch.randint(-code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device) + + codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device) + + count = 0 + for index in range(16): + for i in range(8): + for book in range(nbooks): + codebooks[book, index, 0, i] = count * (10**book) + count += 1 + + print("codes shape", codes.shape) + + for i in range(16): + for book in range(nbooks): + codes[0, i, book] = i + codes[0, -i, book] = i + + weights = dequantize_weight(codes, codebooks, None) + weights2 = ops.aqlm_dequant(codes, codebooks, parts) + + print("weights shape:", weights.shape) + print("weights2 shape:", weights2.shape) + + print("weights are:", weights) + print("weights2 are:", weights2) + + print("first 128 weights are", weights[0, 0:128].to(torch.int32)) + print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32)) + + print("last 128 weights are", weights[0, -128:]) + print("last 128 weights2 are:", weights2[0, -128:]) + + +def main(): + + parser = argparse.ArgumentParser(description="Benchmark aqlm performance.") + + # Add arguments + parser.add_argument("--nbooks", + type=int, + default=1, + help="Number of codebooks (default: 1)") + parser.add_argument("--bits", + type=int, + default=16, + help="Number of bits per code element (default: 16)") + parser.add_argument( + "--test", + type=bool, + default=False, + help="Run the decompression/dequant tester rather than benchmarking " + "(default: False)") + + # Parse the arguments + args = parser.parse_args() + + # Extract values + nbooks = args.nbooks + bits = args.bits + + if args.test: + dequant_test(4096, torch.tensor((4096, )), nbooks, bits) + return + + # Otherwise, benchmark. + methods = [ + ops.aqlm_gemm, + dequant_out_scale, + generic_dequantize_gemm, + optimized_dequantize_gemm, + dequant_weight_scale, + torch_mult, + dequant_no_scale, + ] + + filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv" + print(f"writing benchmarks to file {filename}") + with open(filename, "w") as f: + sys.stdout = f + + print('m | k | n | n parts', end='') + for method in methods: + print(f" | {method.__name__.replace('_', ' ')} (µs)", end='') + print('') + + # These are reasonable prefill sizes. + ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )), + (4096, (11008, 11008)), (11008, (4096, ))) + + # reasonable ranges for m. + for m in [ + 1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112, + 128, 256, 512, 1024, 1536, 2048, 3072, 4096 + ]: + print(f'{m}', file=sys.__stdout__) + for ksp in ksandpartions: + run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, + methods) + + sys.stdout = sys.__stdout__ + + +def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, + methods): + + # I didn't see visible improvements from increasing these, but feel free :) + num_warmup_trials = 1 + num_trials = 1 + + num_calls = 100 + + # warmup. + for method in methods: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + m=m, + k=k, + parts=parts, + nbooks=nbooks, + bits=bits, + method=method, + ) + + n = parts.sum().item() + print(f'{m} | {k} | {n} | {parts.tolist()}', end='') + + for method in methods: + best_time_us = 1e20 + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + m=m, + k=k, + parts=parts, + nbooks=nbooks, + bits=bits, + method=method, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + + if kernel_dur_us < best_time_us: + best_time_us = kernel_dur_us + + print(f' | {kernel_dur_us:.0f}', end='') + + print('') + + +def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor, + nbooks: int, bits: int, method) -> float: + + n = parts.sum().item() + + device = torch.device('cuda:0') + + input = torch.randn((1, m, k), dtype=torch.float16, device=device) + + code_range = (1 << bits) // 2 + ingroups = 8 + + codes = torch.randint(-code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device) + + codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device) + + scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device) + + # for comparison to just a pytorch mult. + weights = torch.randn((n, k), dtype=torch.float16, device=device) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + if method is torch_mult: + for i in range(num_calls): + torch_mult(input, weights, scales) + else: + for i in range(num_calls): + method(input, codes, codebooks, scales, parts, None) + + end_event.record() + end_event.synchronize() + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py new file mode 100644 index 0000000000000..b771911781574 --- /dev/null +++ b/benchmarks/kernels/benchmark_marlin.py @@ -0,0 +1,233 @@ +import argparse + +import torch +import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MarlinWorkspace, marlin_24_quantize, marlin_quantize) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, quantize_weights, sort_weights) + +DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] + +ACT_ORDER_OPTS = [False, True] +K_FULL_OPTS = [False, True] + + +def bench_run(results, model, act_order, is_k_full, num_bits, group_size, + size_m, size_k, size_n): + label = "Quant Matmul" + + sub_label = ("{}, act={} k_full={}, b={}, g={}, " + "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, + group_size, size_m, size_k, size_n)) + + print(f"Testing: {sub_label}") + + a = torch.randn(size_m, size_k).to(torch.half).cuda() + b = torch.rand(size_k, size_n).to(torch.half).cuda() + + a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) + + # Marlin quant + ( + marlin_w_ref, + marlin_q_w, + marlin_s, + marlin_g_idx, + marlin_sort_indices, + marlin_rand_perm, + ) = marlin_quantize(b, num_bits, group_size, act_order) + + # Marlin_24 quant + (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, + marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) + + # GPTQ quant + (w_ref, q_w, s, g_idx, + rand_perm) = quantize_weights(b, num_bits, group_size, act_order) + q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" + # so that group ids are increasing + repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) + if act_order: + (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) + + # Prepare + marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_MAX_PARALLEL) + + globals = { + # Gen params + "num_bits": num_bits, + "group_size": group_size, + "size_m": size_m, + "size_n": size_n, + "size_k": size_k, + "a": a, + "a_tmp": a_tmp, + # Marlin params + "marlin_w_ref": marlin_w_ref, + "marlin_q_w": marlin_q_w, + "marlin_s": marlin_s, + "marlin_g_idx": marlin_g_idx, + "marlin_sort_indices": marlin_sort_indices, + "marlin_rand_perm": marlin_rand_perm, + "marlin_workspace": marlin_workspace, + "is_k_full": is_k_full, + # Marlin_24 params + "marlin_24_w_ref": marlin_24_w_ref, + "marlin_24_q_w_comp": marlin_24_q_w_comp, + "marlin_24_meta": marlin_24_meta, + "marlin_24_s": marlin_24_s, + "marlin_24_workspace": marlin_24_workspace, + # GPTQ params + "q_w_gptq": q_w_gptq, + "repack_sort_indices": repack_sort_indices, + # Kernels + "gptq_marlin_gemm": ops.gptq_marlin_gemm, + "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, + "gptq_marlin_repack": ops.gptq_marlin_repack, + } + + min_run_time = 1 + + # Warmup pytorch + for i in range(5): + torch.matmul(a, marlin_w_ref) + + results.append( + benchmark.Timer( + stmt="torch.matmul(a, marlin_w_ref)", + globals=globals, + label=label, + sub_label=sub_label, + description="pytorch_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_24_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_repack", + ).blocked_autorange(min_run_time=min_run_time)) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results = [] + + for model in args.models: + for layer in WEIGHT_SHAPES[model]: + size_k = layer[0] + size_n = layer[1] + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for act_order in ACT_ORDER_OPTS: + if len(args.limit_act_order + ) > 0 and act_order not in args.limit_act_order: + continue + + for is_k_full in K_FULL_OPTS: + if len(args.limit_k_full + ) > 0 and is_k_full not in args.limit_k_full: + continue + + for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + if len(args.limit_num_bits + ) > 0 and num_bits not in args.limit_num_bits: + continue + + for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: + if len( + args.limit_group_size + ) > 0 and group_size not in args.limit_group_size: + continue + + # For act_order, the group_size must be less than + # size_k + if act_order and (group_size == size_k + or group_size == -1): + continue + + for size_m in args.batch_sizes: + bench_run(results, model, act_order, is_k_full, + num_bits, group_size, size_m, size_k, + size_n) + + compare = benchmark.Compare(results) + compare.print() + + +# For quick benchmarking use: +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501 +# +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark Marlin across specified models/shapes/batches") + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[]) + parser.add_argument("--limit-act-order", nargs="+", type=int, default=[]) + parser.add_argument("--limit-k-full", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py index 8e976fbcb3028..196ec8cfce88e 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe.py +++ b/benchmarks/kernels/benchmark_mixtral_moe.py @@ -1,3 +1,4 @@ +import argparse import json import os import sys @@ -5,68 +6,70 @@ import torch import torch.nn.functional as F import triton +from tqdm import tqdm from vllm.model_executor.layers.fused_moe import (fused_moe, get_config_file_name) -os.environ['CUDA_VISIBLE_DEVICES'] = '0' - -def main(): +def main(model, tp_size, gpu, dtype: str): + os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) method = fused_moe for bs in [ 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096 ]: - run_grid(bs, method=method) - - -def run_grid(bs, method): - d_model = 4096 + run_grid(bs, + model=model, + method=method, + gpu=gpu, + tp_size=tp_size, + dtype=dtype) + + +def run_grid(bs, model, method, gpu, tp_size, dtype: str): + if model == '8x7B': + d_model = 4096 + model_intermediate_size = 14336 + num_layers = 32 + elif model == '8x22B': + d_model = 6144 + model_intermediate_size = 16384 + num_layers = 56 + else: + raise ValueError(f'Unsupported Mixtral model {model}') num_total_experts = 8 top_k = 2 - tp_size = 2 - model_intermediate_size = 14336 - num_layers = 32 + # tp_size = 2 num_calls = 100 num_warmup_trials = 1 num_trials = 1 configs = [] - if bs <= 16: - BLOCK_SIZES_M = [16] - elif bs <= 32: - BLOCK_SIZES_M = [16, 32] - elif bs <= 64: - BLOCK_SIZES_M = [16, 32, 64] - elif bs <= 128: - BLOCK_SIZES_M = [16, 32, 64, 128] - else: - BLOCK_SIZES_M = [16, 32, 64, 128, 256] for block_size_n in [32, 64, 128, 256]: - for block_size_m in BLOCK_SIZES_M: + for block_size_m in [16, 32, 64, 128, 256]: for block_size_k in [64, 128, 256]: for group_size_m in [1, 16, 32, 64]: for num_warps in [4, 8]: - configs.append({ - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - "num_warps": num_warps, - "num_stages": 4, - }) + for num_stages in [2, 3, 4, 5]: + configs.append({ + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + }) best_config = None best_time_us = 1e20 - for config in configs: - print(f'{tp_size=} {bs=}') - print(f'{config}') + print(f'{tp_size=} {bs=}') + + for config in tqdm(configs): # warmup - print('warming up') try: for _ in range(num_warmup_trials): run_timing( @@ -79,12 +82,12 @@ def run_grid(bs, method): model_intermediate_size=model_intermediate_size, method=method, config=config, + dtype=dtype, ) except triton.runtime.autotuner.OutOfResources: continue # trial - print('benchmarking') for _ in range(num_trials): kernel_dur_ms = run_timing( num_calls=num_calls, @@ -96,6 +99,7 @@ def run_grid(bs, method): model_intermediate_size=model_intermediate_size, method=method, config=config, + dtype=dtype, ) kernel_dur_us = 1000 * kernel_dur_ms @@ -105,16 +109,18 @@ def run_grid(bs, method): best_config = config best_time_us = kernel_dur_us - print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' - f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' - f'{d_model=} {model_intermediate_size=} {num_layers=}') + tqdm.write( + f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' + f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' + f'{d_model=} {model_intermediate_size=} {num_layers=}') print("best_time_us", best_time_us) print("best_config", best_config) # holds Dict[str, Dict[str, int]] filename = get_config_file_name(num_total_experts, - model_intermediate_size // tp_size) + model_intermediate_size // tp_size, + "float8" if dtype == "float8" else None) print(f"writing config to file {filename}") existing_content = {} if os.path.exists(filename): @@ -128,27 +134,48 @@ def run_grid(bs, method): def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, top_k: int, tp_size: int, model_intermediate_size: int, method, - config) -> float: + config, dtype: str) -> float: shard_intermediate_size = model_intermediate_size // tp_size hidden_states = torch.rand( (bs, d_model), device="cuda:0", - dtype=torch.bfloat16, + dtype=torch.float16, ) - ws = torch.rand( + w1 = torch.rand( (num_total_experts, 2 * shard_intermediate_size, d_model), device=hidden_states.device, dtype=hidden_states.dtype, ) - w2s = torch.rand( + w2 = torch.rand( (num_total_experts, d_model, shard_intermediate_size), device=hidden_states.device, dtype=hidden_states.dtype, ) + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + + if dtype == "float8": + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + w1_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + w2_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + a1_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + a2_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + gating_output = F.softmax(torch.rand( (num_calls, bs, num_total_experts), device=hidden_states.device, @@ -163,13 +190,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, for i in range(num_calls): hidden_states = method( hidden_states=hidden_states, - w1=ws, - w2=w2s, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, gating_output=gating_output[i], topk=2, renormalize=True, inplace=True, override_config=config, + use_fp8=dtype == "float8", ) end_event.record() end_event.synchronize() @@ -179,4 +211,29 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, if __name__ == "__main__": - sys.exit(main()) + parser = argparse.ArgumentParser( + prog='benchmark_mixtral_moe', + description='Benchmark and tune the fused_moe kernel', + ) + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['float8', 'float16'], + help='Data type used for fused_moe kernel computations', + ) + parser.add_argument('--model', + type=str, + default='8x7B', + choices=['8x7B', '8x22B'], + help='The Mixtral model to benchmark') + parser.add_argument('--tp-size', + type=int, + default=2, + help='Tensor paralleli size') + parser.add_argument('--gpu', + type=int, + default=0, + help="GPU ID for benchmarking") + args = parser.parse_args() + sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype)) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 24f734ce8cce4..0fcfc0a295ca2 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -5,7 +5,7 @@ import torch -from vllm._C import ops +from vllm import _custom_ops as ops from vllm._custom_C import paged_attention_custom from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random @@ -17,7 +17,7 @@ def main( version: str, num_seqs: int, - context_len: int, + seq_len: int, num_query_heads: int, num_kv_heads: int, head_size: int, @@ -49,12 +49,12 @@ def main( dtype=torch.float, device=device) - context_lens = [context_len for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) + seq_lens = [seq_len for _ in range(num_seqs)] + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device) # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -81,7 +81,7 @@ def main( if not args.custom_paged_attn: global PARTITION_SIZE PARTITION_SIZE = 512 - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), @@ -114,9 +114,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -134,9 +134,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -153,9 +153,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, ) @@ -189,12 +189,12 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument("--context-len", type=int, default=4096) + parser.add_argument("--seq_len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 256], + choices=[64, 80, 96, 112, 128, 192, 256], default=128) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") @@ -207,13 +207,11 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: parser.add_argument( "--kv-cache-dtype", type=str, - choices=["auto", "fp8"], + choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"], default="auto", - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + help="Data type for kv cache storage. If 'auto', will use model " + "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " + "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") parser.add_argument("--custom-paged-attn", action="store_true", help="Use custom paged attention") @@ -225,7 +223,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: main( version=args.version, num_seqs=args.batch_size, - context_len=args.context_len, + seq_len=args.seq_len, num_query_heads=args.num_query_heads, num_kv_heads=args.num_kv_heads, head_size=args.head_size, diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 9188e811e2982..00e55f6060b52 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -93,7 +93,7 @@ def benchmark_rope_kernels_multi_lora( parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 256], + choices=[64, 80, 96, 112, 128, 192, 256], default=128) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--dtype", diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py new file mode 100644 index 0000000000000..4eeeca35a37cc --- /dev/null +++ b/benchmarks/kernels/benchmark_shapes.py @@ -0,0 +1,75 @@ +WEIGHT_SHAPES = { + "ideal": [[4 * 256 * 32, 256 * 32]], + "mistralai/Mistral-7B-v0.1/TP1": [ + [4096, 6144], + [4096, 4096], + [4096, 28672], + [14336, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP2": [ + [4096, 3072], + [2048, 4096], + [4096, 14336], + [7168, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP4": [ + [4096, 1536], + [1024, 4096], + [4096, 7168], + [3584, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP1": [ + [4096, 12288], + [4096, 4096], + [4096, 22016], + [11008, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP2": [ + [4096, 6144], + [2048, 4096], + [4096, 11008], + [5504, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP4": [ + [4096, 3072], + [1024, 4096], + [4096, 5504], + [2752, 4096], + ], + "meta-llama/Llama-2-13b-hf/TP1": [ + [5120, 15360], + [5120, 5120], + [5120, 27648], + [13824, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP2": [ + [5120, 7680], + [2560, 5120], + [5120, 13824], + [6912, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP4": [ + [5120, 3840], + [1280, 5120], + [5120, 6912], + [3456, 5120], + ], + "meta-llama/Llama-2-70b-hf/TP1": [ + [8192, 10240], + [8192, 8192], + [8192, 57344], + [28672, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP2": [ + [8192, 5120], + [4096, 8192], + [8192, 28672], + [14336, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP4": [ + [8192, 2560], + [2048, 8192], + [8192, 14336], + [7168, 8192], + ], +} diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh index 64d3c4f4b3889..f491c90d0683e 100755 --- a/benchmarks/launch_tgi_server.sh +++ b/benchmarks/launch_tgi_server.sh @@ -4,7 +4,7 @@ PORT=8000 MODEL=$1 TOKENS=$2 -docker run --gpus all --shm-size 1g -p $PORT:80 \ +docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \ -v $PWD/data:/data \ ghcr.io/huggingface/text-generation-inference:1.4.0 \ --model-id $MODEL \ diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py new file mode 100644 index 0000000000000..c846e47de1fcf --- /dev/null +++ b/benchmarks/overheads/benchmark_hashing.py @@ -0,0 +1,63 @@ +import argparse +import cProfile +import pstats + +from vllm import LLM, SamplingParams + +# A very long prompt, total number of tokens is about 15k. +LONG_PROMPT = ["You are an expert in large language models, aren't you?" + ] * 1000 +LONG_PROMPT = ' '.join(LONG_PROMPT) + + +def main(args): + llm = LLM( + model=args.model, + enforce_eager=True, + enable_prefix_caching=True, + tensor_parallel_size=args.tensor_parallel_size, + use_v2_block_manager=args.use_v2_block_manager, + ) + + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + profiler = cProfile.Profile() + + print("------warm up------") + for i in range(3): + output = llm.generate(LONG_PROMPT, sampling_params) + print(output[0].outputs[0].text) + + print("------start generating------") + for i in range(3): + profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', + globals(), locals()) + + # analyze the runtime of hashing function + stats = pstats.Stats(profiler) + stats.sort_stats('cumulative') + total_time = 0 + total_calls = 0 + for func in stats.stats: + if 'hash_of_block' in func[2]: + total_time = stats.stats[func][3] + total_calls = stats.stats[func][0] + percentage = (total_time / stats.total_tt) * 100 + print(f"Hashing took {total_time:.2f} seconds," + f"{percentage:.2f}% of the total runtime.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Benchmark the performance of hashing function in' + 'automatic prefix caching.') + parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') + parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--output-len', type=int, default=10) + parser.add_argument('--enable-prefix-caching', + action='store_true', + help='enable prefix caching') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='Use BlockSpaceMangerV2') + args = parser.parse_args() + main(args) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 8339fa6cccb0d..b173ee106d562 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) "Failed to determine torch nvcc compiler flags") if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) - list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") + list(APPEND GPU_FLAGS "-DENABLE_FP8") endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(REMOVE_ITEM GPU_FLAGS @@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) list(APPEND GPU_FLAGS "-DUSE_ROCM" - "-DENABLE_FP8_E4M3" + "-DENABLE_FP8" "-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_OPERATORS__" "-fno-gpu-rdc") diff --git a/collect_env.py b/collect_env.py index 8982fba024274..1ecfeb8e22e2f 100644 --- a/collect_env.py +++ b/collect_env.py @@ -63,6 +63,7 @@ "magma", "triton", "optree", + "nccl", } DEFAULT_PIP_PATTERNS = { @@ -73,6 +74,7 @@ "triton", "optree", "onnx", + "nccl", } diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 24d972702c858..867f63f12de4b 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -10,11 +10,11 @@ namespace vllm { // Activation and gating kernel template. -template +template __global__ void act_and_mul_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., 2, d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); @@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel( } } -template +template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); + return (T)(((float)x) / (1.0f + expf((float)-x))); } -template +template __device__ __forceinline__ T gelu_kernel(const T& x) { // Equivalent to PyTorch GELU with 'none' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 - const float f = (float) x; + const float f = (float)x; constexpr float ALPHA = M_SQRT1_2; - return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); + return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); } -template +template __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { // Equivalent to PyTorch GELU with 'tanh' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 - const float f = (float) x; + const float f = (float)x; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float KAPPA = 0.044715; float x_cube = f * f * f; float inner = BETA * (f + KAPPA * x_cube); - return (T) (0.5f * f * (1.0f + ::tanhf(inner))); + return (T)(0.5f * f * (1.0f + ::tanhf(inner))); } -} // namespace vllm +} // namespace vllm // Launch activation and gating kernel. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "act_and_mul_kernel", \ - [&] { \ - vllm::act_and_mul_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); - -void silu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + +void silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } -void gelu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); } -void gelu_tanh_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); } @@ -96,11 +90,11 @@ void gelu_tanh_and_mul( namespace vllm { // Element-wise activation kernel template. -template +template __global__ void activation_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); @@ -108,54 +102,49 @@ __global__ void activation_kernel( } } -} // namespace vllm +} // namespace vllm // Launch element-wise activation kernel. -#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / d; \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "activation_kernel", \ - [&] { \ - vllm::activation_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + vllm::activation_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); namespace vllm { -template +template __device__ __forceinline__ T gelu_new_kernel(const T& x) { - const float x3 = (float) (x * x * x); - const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); - return ((T) 0.5) * x * (((T) 1.0) + t); + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); } -template +template __device__ __forceinline__ T gelu_fast_kernel(const T& x) { - const float f = (float) x; - const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); - return ((T) 0.5) * x * (((T) 1.0) + t); + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); } -} // namespace vllm +} // namespace vllm -void gelu_new( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_new(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_fast(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh index 31fb401cbe2c1..62409c0cce93e 100644 --- a/csrc/attention/attention_generic.cuh +++ b/csrc/attention/attention_generic.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -22,31 +23,31 @@ namespace vllm { // A vector type to store Q, K, V elements. -template +template struct Vec {}; // A vector type to store FP32 accumulators. -template +template struct FloatVec {}; // Template vector operations. -template +template inline __device__ Acc mul(A a, B b); -template +template inline __device__ float sum(T v); -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ void zero(T& dst) { constexpr int WORDS = sizeof(T) / 4; union { @@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { dst = tmp.raw; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index b114ab0cfdb57..ece7e749e7312 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -19,27 +20,23 @@ #include #include #include +#include #include "attention_dtypes.h" #include "attention_utils.cuh" -#if defined(ENABLE_FP8_E5M2) -#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "../quantization/fp8/amd_detail/quant_utils.cuh" -#endif - -#include - #ifdef USE_ROCM #include - typedef __hip_bfloat16 __nv_bfloat16; + #include "../quantization/fp8/amd/quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif #ifndef USE_ROCM -#define WARP_SIZE 32 + #define WARP_SIZE 32 #else -#define WARP_SIZE warpSize + #define WARP_SIZE warpSize #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -49,7 +46,7 @@ namespace vllm { // Utility function for attention softmax. -template +template inline __device__ float block_sum(float* red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; @@ -86,58 +83,65 @@ inline __device__ float block_sum(float* red_smem, float sum) { // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - bool IS_FP8_KV_CACHE, - int PARTITION_SIZE = 0> // Zero means no partitioning. +template // Zero means no partitioning. __device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { // No work to do. Terminate the thread block. return; } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -147,19 +151,18 @@ __device__ void paged_attention_kernel( const int num_heads = gridDim.x; const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; - const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or compute 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using Quant_vec = typename Vec::Type; -#endif constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; @@ -169,18 +172,21 @@ __device__ void paged_attention_kernel( // Load the query to registers. // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; #pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs // Memory planning. extern __shared__ char shared_mem[]; @@ -199,58 +205,101 @@ __device__ void paged_attention_kernel( // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. + logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); // Load a key to registers. // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; K_vec k_vecs[NUM_VECS_PER_THREAD]; #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. - k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); -#elif defined(ENABLE_FP8_E4M3) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k - // cache vec to k vec in higher precision (FP16, BFloat16, etc.) - k_vecs[j] = fp8_e4m3::scaled_vec_conversion(k_vec_quant, kv_scale); -#else - assert(false); -#endif + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); } else { - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, kv_scale); } } // Compute dot product. // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= context_len; + const bool mask = token_idx >= seq_len; logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); @@ -298,13 +347,12 @@ __device__ void paged_attention_kernel( // If partitioning is enabled, store the max logit and exp_sum. if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *exp_sums_ptr = exp_sum; } @@ -312,14 +360,13 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using V_quant_vec = typename Vec::Type; -#endif using Float_L_vec = typename FloatVec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -330,48 +377,55 @@ __device__ void paged_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); - const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); - // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); -#elif defined(ENABLE_FP8_E4M3) - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); - // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert - // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.) - v_vec = fp8_e4m3::scaled_vec_conversion(v_quant_vec, kv_scale); -#else - assert(false); -#endif - } else { + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { v_vec = *reinterpret_cast(v_ptr + offset); + } else { + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8::scaled_convert(v_quant_vec, + kv_scale); } - if (block_idx == num_context_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the context, - // we should explicitly zero out the values since they may contain NaNs. - // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + if (block_idx == num_seq_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; } } accs[i] += dot(logits_vec, v_vec); @@ -390,8 +444,8 @@ __device__ void paged_attention_kernel( accs[i] = acc; } - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. __syncthreads(); // Perform reduction across warps. @@ -428,9 +482,9 @@ __device__ void paged_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -442,89 +496,96 @@ __device__ void paged_attention_kernel( } // Grid: (num_heads, num_seqs, 1). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - bool IS_FP8_KV_CACHE> +template __global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - bool IS_FP8_KV_CACHE, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride, kv_scale); + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, kv_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs). -template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { out_ptr[i] = tmp_out_ptr[i]; } @@ -543,8 +604,9 @@ __global__ void paged_attention_v2_reduce_kernel( // Load max logits to shared memory. float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const float l = max_logits_ptr[i]; @@ -573,9 +635,11 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. - float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float l = shared_max_logits[i]; @@ -588,65 +652,56 @@ __global__ void paged_attention_v2_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out. - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { float acc = 0.0f; for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; } from_float(out_ptr[i], acc); } } -} // namespace vllm - -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ - vllm::paged_attention_v1_kernel<<>>( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); +} // namespace vllm + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + kv_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. -template< - typename T, - typename CACHE_T, - int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, +template + int NUM_THREADS = 1024> #else - int NUM_THREADS = 128> + int NUM_THREADS = 128> #endif void paged_attention_v1_launcher( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, - const c10::optional& alibi_slopes, - float kv_scale) { + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float kv_scale, + const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -659,20 +714,22 @@ void paged_attention_v1_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_context_len * sizeof(float); + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! @@ -701,6 +758,9 @@ void paged_attention_v1_launcher( case 128: LAUNCH_PAGED_ATTENTION_V1(128); break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; case 256: LAUNCH_PAGED_ATTENTION_V1(256); break; @@ -710,133 +770,97 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ - alibi_slopes, \ - kv_scale); +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] - int block_size, - int max_context_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) } -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - context_lens_ptr, \ - max_num_partitions); - -template< - typename T, - typename CACHE_T, - int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template + int NUM_THREADS = 1024, int PARTITION_SIZE = 1024> #else - int NUM_THREADS = 128, - int PARTITION_SIZE = 512> + int NUM_THREADS = 128, int PARTITION_SIZE = 512> #endif void paged_attention_v2_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, - const c10::optional& alibi_slopes, - float kv_scale) { + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float kv_scale, + const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -849,9 +873,10 @@ void paged_attention_v2_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -861,10 +886,10 @@ void paged_attention_v2_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int logits_size = PARTITION_SIZE * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); @@ -897,6 +922,9 @@ void paged_attention_v2_launcher( case 128: LAUNCH_PAGED_ATTENTION_V2(128); break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; case 256: LAUNCH_PAGED_ATTENTION_V2(256); break; @@ -906,84 +934,68 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ - alibi_slopes, \ - kv_scale); +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] - int block_size, - int max_context_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) } #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index 22273c11d483e..826b0edffae67 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -26,7 +27,7 @@ namespace vllm { // Q*K^T operation. -template +template inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { using A_vec = typename FloatVec::Type; // Compute the parallel products for Q*K^T (treat vector lanes separately). @@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { return qk; } -template +template struct Qk_dot { - template + template static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { return qk_dot_(q, k); } }; -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 31e0cee01d2e1..3cdcb95e08099 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -28,8 +30,8 @@ #include #include - typedef __hip_bfloat162 __nv_bfloat162; - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; #endif #include @@ -50,37 +52,37 @@ struct bf16_8_t { }; // BF16 vector types for Q, K, V. -template<> +template <> struct Vec<__nv_bfloat16, 1> { using Type = __nv_bfloat16; }; -template<> +template <> struct Vec<__nv_bfloat16, 2> { using Type = __nv_bfloat162; }; -template<> +template <> struct Vec<__nv_bfloat16, 4> { using Type = bf16_4_t; }; -template<> +template <> struct Vec<__nv_bfloat16, 8> { using Type = bf16_8_t; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec<__nv_bfloat16> { using Type = float; }; -template<> +template <> struct FloatVec<__nv_bfloat162> { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { assert(false); #else #ifndef USE_ROCM - return a + b; + return a + b; #else - return __hadd(a, b); + return __hadd(a, b); #endif #endif } @@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #endif } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #endif } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); } -template<> +template <> inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { bf16_4_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_4_t c; @@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { bf16_8_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_8_t c; @@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { float fa = __bfloat162float(a); float fb = __bfloat162float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { float2 fa = bf1622float2(a); float2 fb = bf1622float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul(bf162bf162(a), b); } -template<> +template <> inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); Float4_ fc; @@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); Float8_ fc; @@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { } // Vector fused multiply-add. -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf #endif } -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(__nv_bfloat16 v) { return __bfloat162float(v); } -template<> +template <> inline __device__ float sum(__nv_bfloat162 v) { float2 vf = bf1622float2(v); return vf.x + vf.y; } -template<> +template <> inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); } -template<> +template <> inline __device__ float sum(bf16_8_t v) { return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); } @@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { #endif } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index d3271e69cd69d..3a1815f0ed4fc 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -30,37 +32,37 @@ namespace vllm { // FP16 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = uint16_t; }; -template<> +template <> struct Vec { using Type = uint32_t; }; -template<> +template <> struct Vec { using Type = uint2; }; -template<> +template <> struct Vec { using Type = uint4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { return b; #else union { - uint32_t u32; - uint16_t u16[2]; + uint32_t u32; + uint16_t u16[2]; } tmp; tmp.u16[0] = a; tmp.u16[1] = a; @@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { } tmp; #ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif #else tmp.u16[0] = float_to_half(f.x); @@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { uint16_t c; #ifndef USE_ROCM @@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { uint32_t c; #ifndef USE_ROCM @@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ uint2 mul(uint2 a, uint2 b) { uint2 c; c.x = mul(a.x, b.x); @@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { return c; } -template<> +template <> inline __device__ uint2 mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); uint2 c; @@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint4 a, uint4 b) { uint4 c; c.x = mul(a.x, b.x); @@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); uint4 c; @@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { return c; } -template<> +template <> inline __device__ float mul(uint16_t a, uint16_t b) { float fa = half_to_float(a); float fb = half_to_float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(uint32_t a, uint32_t b) { float2 fa = half2_to_float2(a); float2 fb = half2_to_float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ Float4_ mul(uint2 a, uint2 b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); Float4_ fc; @@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint4 a, uint4 b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); Float8_ fc; @@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; #ifndef USE_ROCM - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); #else - asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" + : "=v"(d) + : "v"(a), "v"(b), "v"(c)); #endif return d; } @@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(uint16_t v) { return half_to_float(v); } -template<> +template <> inline __device__ float sum(uint32_t v) { float2 tmp = half2_to_float2(v); return tmp.x + tmp.y; } -template<> +template <> inline __device__ float sum(uint2 v) { uint32_t c = add(v.x, v.y); return sum(c); } -template<> +template <> inline __device__ float sum(uint4 v) { uint32_t c = add(v.x, v.y); c = add(c, v.z); @@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { } // From float16 to float32. -inline __device__ float to_float(uint16_t u) { - return half_to_float(u); -} +inline __device__ float to_float(uint16_t u) { return half_to_float(u); } -inline __device__ float2 to_float(uint32_t u) { - return half2_to_float2(u); -} +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } inline __device__ Float4_ to_float(uint2 u) { Float4_ tmp; @@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { } // Zero-out a variable. -inline __device__ void zero(uint16_t& dst) { - dst = uint16_t(0); -} +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index b200d2d226eb0..7c6a686db3ba9 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -38,37 +40,35 @@ struct Float8_ { }; // FP32 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = float; }; -template<> +template <> struct Vec { using Type = float2; }; -template<> +template <> struct Vec { using Type = float4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = float4; }; // Vector addition. -inline __device__ float add(float a, float b) { - return a + b; -} +inline __device__ float add(float a, float b) { return a + b; } inline __device__ float2 add(float2 a, float2 b) { float2 c; @@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { } // Vector multiplication. -template<> +template <> inline __device__ float mul(float a, float b) { return a * b; } -template<> +template <> inline __device__ float2 mul(float2 a, float2 b) { float2 c; c.x = a.x * b.x; @@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { return c; } -template<> +template <> inline __device__ float2 mul(float a, float2 b) { float2 c; c.x = a * b.x; @@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { return c; } -template<> +template <> inline __device__ float4 mul(float4 a, float4 b) { float4 c; c.x = a.x * b.x; @@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { return c; } -template<> +template <> inline __device__ float4 mul(float a, float4 b) { float4 c; c.x = a * b.x; @@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { } // Vector fused multiply-add. -inline __device__ float fma(float a, float b, float c) { - return a * b + c; -} +inline __device__ float fma(float a, float b, float c) { return a * b + c; } inline __device__ float2 fma(float2 a, float2 b, float2 c) { float2 d; @@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { } // Vector sum. -template<> +template <> inline __device__ float sum(float v) { return v; } -template<> +template <> inline __device__ float sum(float2 v) { return v.x + v.y; } -template<> +template <> inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -template<> +template <> inline __device__ float sum(Float4_ v) { return v.x.x + v.x.y + v.y.x + v.y.y; } -template<> +template <> inline __device__ float sum(Float8_ v) { return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; } // Vector dot product. -inline __device__ float dot(float a, float b) { - return a * b; -} +inline __device__ float dot(float a, float b) { return a * b; } inline __device__ float dot(float2 a, float2 b) { float2 c = mul(a, b); @@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { } // From float to float. -inline __device__ void from_float(float& dst, float src) { - dst = src; -} +inline __device__ void from_float(float& dst, float src) { dst = src; } -inline __device__ void from_float(float2& dst, float2 src) { - dst = src; -} +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } -inline __device__ void from_float(float4& dst, float4 src) { - dst = src; -} +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } // From float to float. -inline __device__ float to_float(float u) { - return u; -} +inline __device__ float to_float(float u) { return u; } -inline __device__ float2 to_float(float2 u) { - return u; -} +inline __device__ float2 to_float(float2 u) { return u; } -inline __device__ float4 to_float(float4 u) { - return u; -} +inline __device__ float4 to_float(float4 u) { return u; } -inline __device__ Float4_ to_float(Float4_ u) { - return u; -} +inline __device__ Float4_ to_float(Float4_ u) { return u; } -inline __device__ Float8_ to_float(Float8_ u) { - return u; -} +inline __device__ Float8_ to_float(Float8_ u) { return u; } // Zero-out a variable. -inline __device__ void zero(float& dst) { - dst = 0.f; -} +inline __device__ void zero(float& dst) { dst = 0.f; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index d11dee91ebe87..e714e321b0beb 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -3,33 +3,39 @@ #include "attention_generic.cuh" #include -#ifdef ENABLE_FP8_E5M2 -#include -#endif +#ifdef ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 namespace vllm { -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) -// fp8 vector types for quantization of kv cache -template<> +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache +template <> struct Vec { - using Type = uint8_t; + using Type = uint8_t; }; -template<> +template <> struct Vec { - using Type = uint16_t; + using Type = uint16_t; }; -template<> +template <> struct Vec { - using Type = uint32_t; + using Type = uint32_t; }; -template<> +template <> struct Vec { - using Type = uint2; + using Type = uint2; }; -#endif // ENABLE_FP8_E5M2 -} // namespace vllm +} // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index fa26ddb688588..064815b7403db 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -5,21 +5,20 @@ #include #include -void swap_blocks( - torch::Tensor& src, - torch::Tensor& dst, - const std::map& block_mapping); +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping); -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - const std::map>& block_mapping); +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& block_mapping); -void reshape_and_cache( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - const float kv_scale); \ No newline at end of file +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, const float kv_scale); + +void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1638b2aeee77e..2ab63b21db1fb 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -4,10 +4,11 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#if defined(ENABLE_FP8_E5M2) -#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "quantization/fp8/amd_detail/quant_utils.cuh" + +#ifdef USE_ROCM + #include "quantization/fp8/amd/quant_utils.cuh" +#else + #include "quantization/fp8/nvidia/quant_utils.cuh" #endif #include @@ -17,20 +18,17 @@ #ifdef USE_ROCM #include - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16 __nv_bfloat16; #endif -void swap_blocks( - torch::Tensor& src, - torch::Tensor& dst, - const std::map& block_mapping) { +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; if (src_device.is_cuda() && dst_device.is_cuda()) { - TORCH_CHECK( - src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); memcpy_type = cudaMemcpyDeviceToDevice; } else if (src_device.is_cuda() && dst_device.is_cpu()) { memcpy_type = cudaMemcpyDeviceToHost; @@ -40,41 +38,44 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - char *src_ptr = static_cast(src.data_ptr()); - char *dst_ptr = static_cast(dst.data_ptr()); + // NOTE(youkaichao): keep in mind that `block_mapping` should be + // a cpu tensor, otherwise every `item` call will require a gpu-cpu + // synchronization. + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + char* src_ptr = static_cast(src.data_ptr()); + char* dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); - const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); + const at::cuda::OptionalCUDAGuard device_guard( + src_device.is_cuda() ? src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. - for (const auto& pair : block_mapping) { - int64_t src_block_number = pair.first; - int64_t dst_block_number = pair.second; + const int64_t num_blocks = block_mapping.size(0); + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); int64_t src_offset = src_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes; - cudaMemcpyAsync( - dst_ptr + dst_offset, - src_ptr + src_offset, - block_size_in_bytes, - memcpy_type, - stream); + cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, + block_size_in_bytes, memcpy_type, stream); } } namespace vllm { // Grid: (num_layers, num_pairs) -template -__global__ void copy_blocks_kernel( - int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { +template +__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, + int64_t* value_cache_ptrs, + const int64_t* __restrict__ block_mapping, + const int numel_per_block) { const int layer_idx = blockIdx.x; const int pair_idx = blockIdx.y; scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); - scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); + scalar_t* value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; @@ -92,12 +93,11 @@ __global__ void copy_blocks_kernel( } } -} // namespace vllm +} // namespace vllm -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - const std::map>& block_mapping) { +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { @@ -111,29 +111,23 @@ void copy_blocks( int64_t key_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers]; for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { - key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); - value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); - } - // Create block mapping array. - std::vector block_mapping_vec; - for (const auto& pair : block_mapping) { - int64_t src_block_number = pair.first; - for (int64_t dst_block_number : pair.second) { - block_mapping_vec.push_back(src_block_number); - block_mapping_vec.push_back(dst_block_number); - } + key_cache_ptrs[layer_idx] = + reinterpret_cast(key_caches[layer_idx].data_ptr()); + value_cache_ptrs[layer_idx] = + reinterpret_cast(value_caches[layer_idx].data_ptr()); } - int64_t* block_mapping_array = block_mapping_vec.data(); - int num_pairs = block_mapping_vec.size() / 2; + + // block_mapping is a 2D tensor with shape (num_pairs, 2). + int num_pairs = block_mapping.size(0); // Move the data structures to the GPU. // NOTE: This synchronizes the CPU and GPU. - torch::Tensor key_cache_ptrs_tensor = torch::from_blob( - key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor value_cache_ptrs_tensor = torch::from_blob( - value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor block_mapping_tensor = torch::from_blob( - block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); + torch::Tensor key_cache_ptrs_tensor = + torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + torch::Tensor value_cache_ptrs_tensor = + torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); // Launch the kernel. const int numel_per_block = key_caches[0][0].numel(); @@ -142,31 +136,28 @@ void copy_blocks( const at::cuda::OptionalCUDAGuard device_guard(cache_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( - key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { - vllm::copy_blocks_kernel<<>>( - key_cache_ptrs_tensor.data_ptr(), - value_cache_ptrs_tensor.data_ptr(), - block_mapping_tensor.data_ptr(), - numel_per_block); - })); + key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { + vllm::copy_blocks_kernel<<>>( + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); + })); } namespace vllm { -template +template __global__ void reshape_and_cache_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - const int x, - const float kv_scale) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, + // block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, + // block_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x, + const float kv_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -187,60 +178,84 @@ __global__ void reshape_and_cache_kernel( const int x_idx = head_offset / x; const int x_offset = head_offset % x; - const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; - if constexpr (is_fp8_kv_cache) { -#if defined(ENABLE_FP8_E5M2) - key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); - value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); -#elif defined(ENABLE_FP8_E4M3) - key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale); - value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale); -#else - assert(false); -#endif - } else { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { key_cache[tgt_key_idx] = tgt_key; value_cache[tgt_value_idx] = tgt_value; + } else { + key_cache[tgt_key_idx] = + fp8::scaled_convert(tgt_key, kv_scale); + value_cache[tgt_value_idx] = + fp8::scaled_convert(tgt_value, kv_scale); } } } -} // namespace vllm - -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ - vllm::reshape_and_cache_kernel<<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), \ - key_stride, \ - value_stride, \ - num_heads, \ - head_size, \ - block_size, \ - x, \ - kv_scale); +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, + // head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, + // head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, const int key_stride, const int value_stride, + const int num_heads, const int head_size, const int block_size) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int64_t tgt_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; + k_cache[tgt_value_idx] = key[src_key_idx]; + v_cache[tgt_value_idx] = value[src_value_idx]; + } +} +} // namespace vllm + +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), key_stride, value_stride, \ + num_heads, head_size, block_size, x, kv_scale); void reshape_and_cache( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, - const float kv_scale) -{ + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, const float kv_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -254,23 +269,43 @@ void reshape_and_cache( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (kv_cache_dtype == "auto") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, float, false); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); - } - } else if (kv_cache_dtype == "fp8") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, uint8_t, true); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); - } - } else { + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE) +} + +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) { + // FIXME: only support auto datatype, does not support fp8 + if (kv_cache_dtype != "auto") { TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); } -} + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = k_cache.size(1); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + int block_stride = k_cache.stride(0); + TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "reshape_and_cache_flash", [&] { + vllm::reshape_and_cache_flash_kernel + <<>>( + key.data_ptr(), value.data_ptr(), + k_cache.data_ptr(), v_cache.data_ptr(), + slot_mapping.data_ptr(), block_stride, key_stride, + value_stride, num_heads, head_size, block_size); + }); +} \ No newline at end of file diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp index 1bd24eb79d129..becd2ac42f17a 100644 --- a/csrc/cpu/activation.cpp +++ b/csrc/cpu/activation.cpp @@ -1,10 +1,10 @@ #include "cpu_types.hpp" namespace { -template -void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, - scalar_t *__restrict__ output) { +void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input, + scalar_t* __restrict__ output) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); @@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, } } -FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 ones(1.0); return x / (ones + (zeros - x).exp()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w2(0.5); return x * w2 * (ones + (x * w1).er()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w2(0.5); @@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); return x * w2 * (ones + inner.tanh()); } -}; // namespace +}; // namespace -void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { +void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "silu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(silu_and_mul_impl) - activation_kernel(num_tokens, d, - input.data_ptr(), - out.data_ptr()); - CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); } -void gelu_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "gelu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) - activation_kernel(num_tokens, d, - input.data_ptr(), - out.data_ptr()); - CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); } -void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; @@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] }); } -void gelu_new(torch::Tensor &out, torch::Tensor &input) { +void gelu_new(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); @@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) { }); } -void gelu_fast(torch::Tensor &out, torch::Tensor &input) { +void gelu_fast(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 365bbd5e23728..ed8cfbd421f0f 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -2,7 +2,8 @@ namespace { -template struct KernelVecType { +template +struct KernelVecType { using q_load_vec_type = void; using q_vec_type = void; using k_load_vec_type = void; @@ -11,7 +12,8 @@ template struct KernelVecType { using v_load_vec_type = void; }; -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::FP32Vec4; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16; @@ -21,7 +23,8 @@ template <> struct KernelVecType { }; #ifdef __AVX512BF16__ -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32; @@ -30,7 +33,8 @@ template <> struct KernelVecType { using v_load_vec_type = vec_op::BF16Vec16; }; #else -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::BF16Vec16; @@ -41,7 +45,7 @@ template <> struct KernelVecType { #endif template -FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, +FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, const int capacity) { T max = data[0]; for (int i = 1; i < size; ++i) { @@ -67,14 +71,15 @@ FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, } template -FORCE_INLINE std::pair -reduceSoftmaxAlibi(T *data, const int size, const int capacity, - const float alibi_slope, const int start_index, - const int context_len) { - data[0] += alibi_slope * (start_index - context_len + 1); +FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, + const int capacity, + const float alibi_slope, + const int start_index, + const int seq_len) { + data[0] += alibi_slope * (start_index - seq_len + 1); T max = data[0]; for (int i = 1; i < size; ++i) { - T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); + T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1); data[i] = qk; max = max >= qk ? max : qk; } @@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity, } template -FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, +FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, const int size) { T max = max_data[0]; for (int i = 1; i < size; ++i) { @@ -132,9 +137,9 @@ struct reduceQKBlockKernel { static_assert(k_load_vec_type::get_elem_num() % x == 0); static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - FORCE_INLINE static void call(const scalar_t *__restrict__ q, - const scalar_t *__restrict__ k_block, - float *__restrict__ logits, float scale, + FORCE_INLINE static void call(const scalar_t* __restrict__ q, + const scalar_t* __restrict__ k_block, + float* __restrict__ logits, float scale, const int token_num) { const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; @@ -196,8 +201,8 @@ struct reduceQKBlockKernel { template -FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, - acc_t &&acc) { +FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, + acc_t&& acc) { using v_load_vec_type = typename KernelVecType::v_load_vec_type; constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); static_assert(BLOCK_SIZE == ELEM_NUM); @@ -209,66 +214,65 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; }); } -}; // namespace +}; // namespace // Paged attention v1 namespace { template struct paged_attention_v1_impl { - static void - call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + static void call( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] - const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size, block_size] - const int num_kv_heads, const float scale, - const int - *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float *__restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads) { + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { constexpr int x = 16 / sizeof(scalar_t); const int num_queries_per_kv = num_heads / num_kv_heads; static_assert(BLOCK_SIZE == 16); - int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; - int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; - TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); + int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0); const int parallel_work_item_num = omp_get_max_threads(); size_t logits_bytes = - parallel_work_item_num * max_context_len_padded * sizeof(float); - float *logits = (float *)std::aligned_alloc( - 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_context_len_padded] + parallel_work_item_num * max_seq_len_padded * sizeof(float); + float* logits = (float*)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. + // [parallel_work_item_num, max_seq_len_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - int context_len = context_lens[seq_idx]; - const int *seq_block_table = + int seq_len = seq_lens[seq_idx]; + const int* seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; - const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t *__restrict__ q_vec_ptr = + const scalar_t* __restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - const int last_block_token_num = - context_len - (block_num - 1) * BLOCK_SIZE; - float *__restrict__ thread_block_logits = - logits + omp_get_thread_num() * max_context_len_padded; + const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; + float* __restrict__ thread_block_logits = + logits + omp_get_thread_num() * max_seq_len_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t *__restrict__ k_block_cache_ptr = + const scalar_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float *__restrict__ head_block_logits = + float* __restrict__ head_block_logits = thread_block_logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -278,12 +282,11 @@ struct paged_attention_v1_impl { // Compute softmax if (alibi_slopes) { - reduceSoftmaxAlibi(thread_block_logits, context_len, + reduceSoftmaxAlibi(thread_block_logits, seq_len, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, - context_len); + seq_len); } else { - reduceSoftmax(thread_block_logits, context_len, - block_num * BLOCK_SIZE); + reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); } // Compute value @@ -293,14 +296,14 @@ struct paged_attention_v1_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t *__restrict__ out_ptr = + scalar_t* __restrict__ out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float *__restrict__ prob_vec_ptr = + const float* __restrict__ prob_vec_ptr = thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t *__restrict__ v_block_cache_ptr = + const scalar_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -311,7 +314,7 @@ struct paged_attention_v1_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t *__restrict__ next_v_block_cache_ptr = + const scalar_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -340,16 +343,16 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); template void paged_attention_v1_impl_launcher( - torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &context_lens, - int max_context_len, const c10::optional &alibi_slopes) { + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -359,68 +362,73 @@ void paged_attention_v1_impl_launcher( int kv_head_stride = key_cache.stride(1); // NOTE: alibi_slopes is optional. - const float *alibi_slopes_ptr = + const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T *out_ptr = reinterpret_cast(out.data_ptr()); - T *query_ptr = reinterpret_cast(query.data_ptr()); - T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int *block_tables_ptr = block_tables.data_ptr(); - int *context_lens_ptr = context_lens.data_ptr(); + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 256: - LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 192: + LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); + break; + case 256: + LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - context_lens, max_context_len, alibi_slopes); - -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes); + +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace - -void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, - torch::Tensor &key_cache, torch::Tensor &value_cache, - int num_kv_heads, float scale, - torch::Tensor &block_tables, - torch::Tensor &context_lens, int block_size, - int max_context_len, - const c10::optional &alibi_slopes, - const std::string &kv_cache_dtype, float kv_scale) { +} // namespace + +void paged_attention_v1( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) @@ -434,23 +442,24 @@ namespace { template struct paged_attention_v2_impl { static void call( - scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] - float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float - *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, - const int - *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ context_lens, // [num_seqs] + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float *__restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const int num_seqs, const int num_heads, const int max_num_partitions) { constexpr int x = 16 / sizeof(scalar_t); @@ -465,27 +474,25 @@ struct paged_attention_v2_impl { for (int partition_idx = 0; partition_idx < max_num_partitions; ++partition_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= context_len) - continue; + if (start_token_idx >= seq_len) continue; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; const bool no_reduce = (partition_num == 1); - const int context_token_num = - (std::min(context_len, start_token_idx + PARTITION_SIZE) - + const int token_num = + (std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx); - const int block_num = - (context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; const int last_block_token_num = - context_token_num - (block_num - 1) * BLOCK_SIZE; - const int *seq_block_table = block_tables + + token_num - (block_num - 1) * BLOCK_SIZE; + const int* seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx + start_token_idx / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t *__restrict__ q_vec_ptr = + const scalar_t* __restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; @@ -493,10 +500,10 @@ struct paged_attention_v2_impl { // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t *__restrict__ k_block_cache_ptr = + const scalar_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float *__restrict__ head_block_logits = + float* __restrict__ head_block_logits = logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -507,16 +514,16 @@ struct paged_attention_v2_impl { std::pair max_and_sum; if (alibi_slopes) { max_and_sum = reduceSoftmaxAlibi( - logits, context_token_num, block_num * BLOCK_SIZE, - alibi_slopes[head_idx], start_token_idx, context_len); + logits, token_num, block_num * BLOCK_SIZE, + alibi_slopes[head_idx], start_token_idx, seq_len); } else { - max_and_sum = reduceSoftmax(logits, context_token_num, - block_num * BLOCK_SIZE); + max_and_sum = + reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); } - auto &&[max_logit, exp_sum] = max_and_sum; + auto&& [max_logit, exp_sum] = max_and_sum; - scalar_t *__restrict__ output_buffer = nullptr; + scalar_t* __restrict__ output_buffer = nullptr; if (!no_reduce) { auto idx = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; @@ -538,13 +545,13 @@ struct paged_attention_v2_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t *__restrict__ out_ptr = + scalar_t* __restrict__ out_ptr = output_buffer + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float *__restrict__ prob_vec_ptr = + const float* __restrict__ prob_vec_ptr = logits + block_idx * BLOCK_SIZE; - const scalar_t *__restrict__ v_block_cache_ptr = + const scalar_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -555,7 +562,7 @@ struct paged_attention_v2_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t *__restrict__ next_v_block_cache_ptr = + const scalar_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -583,12 +590,11 @@ struct paged_attention_v2_impl { #pragma omp parallel for collapse(2) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) - continue; + if (partition_num == 1) continue; reducePartitonSoftmax( max_logits + seq_idx * num_heads * max_num_partitions + @@ -603,30 +609,29 @@ struct paged_attention_v2_impl { using v_load_vec_type = typename KernelVecType::v_load_vec_type; static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); constexpr int head_elem_num_per_group = - 16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE - // didn't align with 64 bytes + 16; // Note: didn't align with the cacheline size, due to some + // HEAD_SIZE didn't align with 64 bytes static_assert(HEAD_SIZE % head_elem_num_per_group == 0); constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; - const float *__restrict__ rescale_factors = exp_sums; + const float* __restrict__ rescale_factors = exp_sums; #pragma omp parallel for collapse(3) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) - continue; + if (partition_num == 1) continue; - const float *__restrict__ seq_head_rescale_factors = + const float* __restrict__ seq_head_rescale_factors = rescale_factors + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - const scalar_t *__restrict__ seq_head_tmp_out = + const scalar_t* __restrict__ seq_head_tmp_out = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE + group_idx * head_elem_num_per_group; - scalar_t *__restrict__ seq_head_output = + scalar_t* __restrict__ seq_head_output = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + group_idx * head_elem_num_per_group; @@ -645,21 +650,21 @@ struct paged_attention_v2_impl { } }; -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ +#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ + key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, num_seqs, num_heads, \ max_num_partitions); template void paged_attention_v2_impl_launcher( - torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, - torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, - int max_context_len, const c10::optional &alibi_slopes) { + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -670,73 +675,78 @@ void paged_attention_v2_impl_launcher( int max_num_partitions = exp_sums.size(-1); // NOTE: alibi_slopes is optional. - const float *alibi_slopes_ptr = + const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T *out_ptr = reinterpret_cast(out.data_ptr()); - float *exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float *max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T *tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T *query_ptr = reinterpret_cast(query.data_ptr()); - T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int *block_tables_ptr = block_tables.data_ptr(); - int *context_lens_ptr = context_lens.data_ptr(); + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 192: + LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); + break; + case 256: + LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, block_size, \ - max_context_len, alibi_slopes); - -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ + alibi_slopes); + +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V2_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace - -void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, - torch::Tensor &max_logits, torch::Tensor &tmp_out, - torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, - float scale, torch::Tensor &block_tables, - torch::Tensor &context_lens, int block_size, - int max_context_len, - const c10::optional &alibi_slopes, - const std::string &kv_cache_dtype, float kv_scale) { +} // namespace + +void paged_attention_v2( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 7849a5df991b1..2890ba6e2bb32 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -5,25 +5,26 @@ namespace { template -void copy_blocks_cpu_impl( - std::vector &key_caches, - std::vector &value_caches, - const std::vector> mapping_pairs, - const int element_num_per_block, const int layer_num) { - const size_t pair_num = mapping_pairs.size(); +void copy_blocks_cpu_impl(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& mapping_pairs, + const int element_num_per_block, + const int layer_num) { + const size_t pair_num = mapping_pairs.size(0); const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; #pragma omp parallel for collapse(2) for (int layer = 0; layer < layer_num; ++layer) { for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; + int64_t source_offset = + element_num_per_block * mapping_pairs[pair][0].item(); int64_t target_offset = - element_num_per_block * mapping_pairs[pair].second; - scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); - scalar_t *source_ptr = key_cache_ptr + source_offset; - scalar_t *target_ptr = key_cache_ptr + target_offset; + element_num_per_block * mapping_pairs[pair][1].item(); + scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t* source_ptr = key_cache_ptr + source_offset; + scalar_t* target_ptr = key_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); - scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); + scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); source_ptr = value_cache_ptr + source_offset; target_ptr = value_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); @@ -33,9 +34,9 @@ void copy_blocks_cpu_impl( template void reshape_and_cache_cpu_impl( - const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, - scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, - const int64_t *__restrict__ slot_mapping, const int num_tokens, + const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, + const int64_t* __restrict__ slot_mapping, const int num_tokens, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, const int x) { const int block_elem_num = num_heads * head_size * block_size; @@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl( int src_key_head_idx = token_idx * key_stride + head_idx * head_size; int src_value_head_idx = token_idx * value_stride + head_idx * head_size; - const scalar_t *src_key_head_ptr = key + src_key_head_idx; - const scalar_t *src_value_head_ptr = value + src_value_head_idx; + const scalar_t* src_key_head_ptr = key + src_key_head_idx; + const scalar_t* src_value_head_ptr = value + src_value_head_idx; const int64_t block_index = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; - scalar_t *target_key_head_ptr = key_cache + + scalar_t* target_key_head_ptr = key_cache + block_elem_num * block_index + head_idx * block_size * head_size; - scalar_t *target_value_head_ptr = value_cache + + scalar_t* target_value_head_ptr = value_cache + block_elem_num * block_index + head_idx * block_size * head_size; @@ -79,39 +80,31 @@ void reshape_and_cache_cpu_impl( } } } -}; // namespace +}; // namespace -void copy_blocks(std::vector &key_caches, - std::vector &value_caches, - const std::map> &block_mapping) { - int num_layers = key_caches.size(); +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& block_mapping) { + unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { return; } - std::vector> mapping_pairs; - mapping_pairs.reserve(block_mapping.size()); - for (const auto &pair : block_mapping) { - for (const auto &dst : pair.second) { - mapping_pairs.emplace_back(pair.first, dst); - } - } - const int element_num_per_block = key_caches[0][0].numel(); VLLM_DISPATCH_FLOATING_TYPES( key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) - copy_blocks_cpu_impl(key_caches, value_caches, mapping_pairs, + copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, element_num_per_block, num_layers); CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) }); } -void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, - torch::Tensor &key_cache, torch::Tensor &value_cache, - torch::Tensor &slot_mapping, - const std::string &kv_cache_dtype, float kv_scale) { +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); int num_tokens = key.size(0); @@ -135,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, }); } -void swap_blocks(torch::Tensor &src, torch::Tensor &dst, - const std::map &block_mapping) { +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") } diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp index 467f0dc84982c..65d3ddcec5709 100644 --- a/csrc/cpu/layernorm.cpp +++ b/csrc/cpu/layernorm.cpp @@ -2,10 +2,10 @@ namespace { template -void rms_norm_impl(scalar_t *__restrict__ out, - const scalar_t *__restrict__ input, - const scalar_t *__restrict__ weight, const float epsilon, - const int num_tokens, const int hidden_size) { +void rms_norm_impl(scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out, } template -void fused_add_rms_norm_impl(scalar_t *__restrict__ input, - scalar_t *__restrict__ residual, - const scalar_t *__restrict__ weight, - const float epsilon, const int num_tokens, - const int hidden_size) { +void fused_add_rms_norm_impl(scalar_t* __restrict__ input, + scalar_t* __restrict__ residual, + const scalar_t* __restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input, } } } -} // namespace +} // namespace -void rms_norm(torch::Tensor &out, torch::Tensor &input, - torch::Tensor &weight, float epsilon) { +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { CPU_KERNEL_GUARD_IN(rms_norm_impl) rms_norm_impl(out.data_ptr(), input.data_ptr(), - weight.data_ptr(), epsilon, num_tokens, - hidden_size); + weight.data_ptr(), epsilon, num_tokens, + hidden_size); CPU_KERNEL_GUARD_OUT(rms_norm_impl) }); } -void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, - torch::Tensor &weight, float epsilon) { +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index e9b3992204bb2..73bf77e46f538 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -4,22 +4,21 @@ namespace { template void rotary_embedding_impl( - const int64_t - *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t - *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or - /// [num_tokens, num_heads, head_size] - scalar_t - *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or - // [num_tokens, num_kv_heads, head_size] - const scalar_t - *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); - constexpr int ELEM_SIZE = sizeof(scalar_t); const int embed_dim = rot_dim / 2; TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); @@ -27,7 +26,7 @@ void rotary_embedding_impl( #pragma omp parallel for for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; for (int i = 0; i < num_heads; ++i) { const int head_idx = i; @@ -95,16 +94,16 @@ void rotary_embedding_impl( template void rotary_embedding_gptj_impl( - const int64_t - *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t - *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or - /// [num_tokens, num_heads, head_size] - scalar_t - *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or - // [num_tokens, num_kv_heads, head_size] - const scalar_t - *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { @@ -114,13 +113,13 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t *cos_cache_ptr = cache_ptr; - const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * query_stride + head_idx * head_size; - scalar_t *head_query = token_head + query; + scalar_t* head_query = token_head + query; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -142,12 +141,12 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_kv_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t *cos_cache_ptr = cache_ptr; - const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * key_stride + head_idx * head_size; - scalar_t *head_key = key + token_head; + scalar_t* head_key = key + token_head; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -165,11 +164,11 @@ void rotary_embedding_gptj_impl( } } } -}; // namespace +}; // namespace -void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, - torch::Tensor &key, int head_size, - torch::Tensor &cos_sin_cache, bool is_neox) { +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox) { int num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index bba044087f37c..63082393c8102 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); + ops.def("paged_attention_v1", &paged_attention_v1, + "Compute the attention between an input query and the cached " + "keys/values using PagedAttention."); + ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); // Activation ops - ops.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - ops.def( - "gelu_and_mul", - &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def( - "gelu_tanh_and_mul", - &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - ops.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); + ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + ops.def("gelu_and_mul", &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); + ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); // Layernorm - ops.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + ops.def("rms_norm", &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); - ops.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); + ops.def("fused_add_rms_norm", &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); // Rotary embedding - ops.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + ops.def("rotary_embedding", &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); + cache_ops.def("swap_blocks", &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def("copy_blocks", ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def("reshape_and_cache", &reshape_and_cache, + "Reshape the key and value tensors and cache them"); } diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index fc3e217aef7a1..82e55613d915a 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -1,7 +1,7 @@ #pragma once #ifdef USE_ROCM -#include + #include #endif #ifndef USE_ROCM @@ -17,11 +17,14 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) - #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) - #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor(var, lane_mask, width) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) #endif #ifndef USE_ROCM @@ -30,6 +33,13 @@ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #endif +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + #ifndef USE_ROCM #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) @@ -37,4 +47,3 @@ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) #endif - diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 1483484faeb4a..2ba49b339e148 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -2,9 +2,6 @@ #include -int get_device_attribute( - int attribute, - int device_id); +int get_device_attribute(int attribute, int device_id); -int get_max_shared_memory_per_block_device_attribute( - int device_id); +int get_max_shared_memory_per_block_device_attribute(int device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 1a443ef3620cc..7d8e2e19720fa 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -2,34 +2,28 @@ #include #include #endif -int get_device_attribute( - int attribute, - int device_id) -{ - int device, value; - if (device_id < 0) { - cudaGetDevice(&device); - } - else { - device = device_id; - } - cudaDeviceGetAttribute(&value, static_cast(attribute), device); - return value; +int get_device_attribute(int attribute, int device_id) { + int device, value; + if (device_id < 0) { + cudaGetDevice(&device); + } else { + device = device_id; + } + cudaDeviceGetAttribute(&value, static_cast(attribute), + device); + return value; } - -int get_max_shared_memory_per_block_device_attribute( - int device_id) -{ -int attribute; -// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html -// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 +int get_max_shared_memory_per_block_device_attribute(int device_id) { + int attribute; + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html + // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 #ifdef USE_ROCM - attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; + attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; #else - attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; + attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; #endif - return get_device_attribute(attribute, device_id); + return get_device_attribute(attribute, device_id); } diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 3906dcfc80dbf..0b1d95848525a 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -7,11 +7,11 @@ // fake pointer type using fptr_t = uint64_t; -static_assert(sizeof(void *) == sizeof(fptr_t)); +static_assert(sizeof(void*) == sizeof(fptr_t)); -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int rank, bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) @@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); } return (fptr_t) new vllm::CustomAllreduce( - reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } @@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, * 5. A[None].expand(2, -1, -1, -1): Not OK * 6. A[:, 1:, 1:]: Not OK */ -bool _is_weak_contiguous(torch::Tensor &t) { +bool _is_weak_contiguous(torch::Tensor& t) { return t.is_contiguous() || (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, bool full_nvlink) { auto inp_size = inp.numel() * inp.element_size(); // custom allreduce requires input byte size to be multiples of 16 @@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, return false; } -void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, +void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, cudaStream_t stream) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); TORCH_CHECK(_is_weak_contiguous(out)); switch (out.scalar_type()) { case at::ScalarType::Float: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } case at::ScalarType::Half: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), - out.numel()); + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case at::ScalarType::BFloat16: { fa->allreduce( - stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #endif @@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, } } -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); @@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { _all_reduce(_fa, inp, out, stream); } -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out) { +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); @@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, } void dispose(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); delete fa; } int meta_size() { return sizeof(vllm::Signal); } -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets) { - auto fa = reinterpret_cast(_fa); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets) { + auto fa = reinterpret_cast(_fa); fa->register_buffer(handles, offsets, t.data_ptr()); } std::pair, std::vector> get_graph_buffer_ipc_meta( fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); return fa->get_graph_buffer_ipc_meta(); } -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector> &offsets) { - auto fa = reinterpret_cast(_fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); fa->register_graph_buffers(handles, offsets); } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 750e68d42f6c6..1ed49b8aa9cae 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -31,9 +31,9 @@ struct Signal { alignas(128) uint32_t end[kMaxBlocks][8]; }; -struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; +struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; -struct __align__(16) RankSignals { volatile Signal *signals[8]; }; +struct __align__(16) RankSignals { volatile Signal* signals[8]; }; // like std::array, but aligned template @@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) { // scalar add functions // for some reason when compiling with Pytorch, the + operator for half and // bfloat is disabled so we call the intrinsics directly -DINLINE half &assign_add(half &a, half b) { +DINLINE half& assign_add(half& a, half b) { a = __hadd(a, b); return a; } -DINLINE float &assign_add(float &a, float b) { return a += b; } +DINLINE float& assign_add(float& a, float b) { return a += b; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } @@ -80,14 +80,14 @@ template <> DINLINE nv_bfloat16 downcast_s(float val) { return __float2bfloat16(val); } -DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { +DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { a = __hadd(a, b); return a; } #endif template -DINLINE array_t &packed_assign_add(array_t &a, array_t b) { +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { #pragma unroll for (int i = 0; i < N; i++) { assign_add(a.data[i], b.data[i]); @@ -128,7 +128,7 @@ DINLINE O downcast(array_t val) { // prior memory accesses. Note: volatile writes will not be reordered against // other volatile writes. template -DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, +DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, int rank) { if (threadIdx.x < ngpus) { // reset flag for next time @@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]) - ; + while (!self_sg->start[blockIdx.x][threadIdx.x]); } __syncthreads(); } @@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, // barrier in the all reduce kernel. If it's the final synchronization barrier, // we don't need to make any visibility guarantees for prior memory accesses. template -DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, +DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, int rank) { __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. + // the memory model. if constexpr (!final_sync) __threadfence_system(); if (threadIdx.x < ngpus) { // reset flag for next time @@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]) - ; + while (!self_sg->end[blockIdx.x][threadIdx.x]); } if constexpr (!final_sync) __syncthreads(); } template -DINLINE P packed_reduce(const P *ptrs[], int idx) { +DINLINE P packed_reduce(const P* ptrs[], int idx) { A tmp = upcast(ptrs[0][idx]); #pragma unroll for (int i = 1; i < ngpus; i++) { @@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData *_dp, RankSignals sg, - volatile Signal *self_sg, T *__restrict__ result, + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; @@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1) // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { - ((P *)result)[idx] = - packed_reduce((const P **)&dp.ptrs[0], idx); + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } end_sync(sg, self_sg, rank); } template -DINLINE P *get_tmp_buf(volatile Signal *sg) { - return (P *)(((Signal *)sg) + 1); +DINLINE P* get_tmp_buf(volatile Signal* sg) { + return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData *_dp, RankSignals sg, - volatile Signal *self_sg, T *__restrict__ result, + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; @@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1) int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; - const P *ptrs[ngpus]; - P *tmps[ngpus]; + const P* ptrs[ngpus]; + P* tmps[ngpus]; #pragma unroll for (int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; - ptrs[i] = (const P *)_dp->ptrs[target]; + ptrs[i] = (const P*)_dp->ptrs[target]; tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; @@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1) int gather_from_rank = ((rank + i) % ngpus); if (gather_from_rank == ngpus - 1 || idx < part) { int dst_idx = gather_from_rank * part + idx; - ((P *)result)[dst_idx] = tmps[i][idx]; + ((P*)result)[dst_idx] = tmps[i][idx]; } } } @@ -261,14 +258,14 @@ class CustomAllreduce { // below are device pointers RankSignals sg_; - std::unordered_map buffers_; - Signal *self_sg_; + std::unordered_map buffers_; + Signal* self_sg_; // stores the registered device pointers from all ranks RankData *d_rank_data_base_, *d_rank_data_end_; - std::vector graph_unreg_buffers_; + std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers - std::map ipc_handles_; + std::map ipc_handles_; /** * meta is a pointer to device metadata and temporary buffer for allreduce. @@ -279,22 +276,22 @@ class CustomAllreduce { * note: this class does not own any device memory. Any required buffers * are passed in from the constructor */ - CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, - const cudaIpcMemHandle_t *handles, - const std::vector &offsets, int rank, + CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t* handles, + const std::vector& offsets, int rank, bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink), self_sg_(meta), - d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { - Signal *rank_sg; + Signal* rank_sg; if (i != rank_) { - char *handle = open_ipc_handle(&handles[i]); + char* handle = open_ipc_handle(&handles[i]); handle += offsets[i]; - rank_sg = (Signal *)handle; + rank_sg = (Signal*)handle; } else { rank_sg = self_sg_; } @@ -302,13 +299,13 @@ class CustomAllreduce { } } - char *open_ipc_handle(const void *ipc_handle) { + char* open_ipc_handle(const void* ipc_handle) { auto [it, new_handle] = - ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); + ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); if (new_handle) { - char *ipc_ptr; - CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, - *((const cudaIpcMemHandle_t *)ipc_handle), + char* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } @@ -323,7 +320,7 @@ class CustomAllreduce { std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; - void *base_ptr; + void* base_ptr; // note: must share the base address of each allocation, or we get wrong // address if (cuPointerGetAttribute(&base_ptr, @@ -331,8 +328,8 @@ class CustomAllreduce { (CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(cudaIpcGetMemHandle( - (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); - offsets[i] = ((char *)ptr) - ((char *)base_ptr); + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); } return std::make_pair(handles, offsets); } @@ -344,13 +341,13 @@ class CustomAllreduce { std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } - void register_buffer(const std::vector &handles, - const std::vector &offsets, void *self) { + void register_buffer(const std::vector& handles, + const std::vector& offsets, void* self) { check_rank_data_capacity(); RankData data; for (int i = 0; i < world_size_; i++) { if (i != rank_) { - char *handle = open_ipc_handle(handles[i].data()); + char* handle = open_ipc_handle(handles[i].data()); handle += offsets[i]; data.ptrs[i] = handle; } else { @@ -371,17 +368,17 @@ class CustomAllreduce { // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. void register_graph_buffers( - const std::vector &handles, - const std::vector> &offsets) { + const std::vector& handles, + const std::vector>& offsets) { auto num_buffers = graph_unreg_buffers_.size(); check_rank_data_capacity(num_buffers); std::vector rank_data(num_buffers); for (int i = 0; i < num_buffers; i++) { auto self_ptr = graph_unreg_buffers_[i]; - auto &rd = rank_data[i]; + auto& rd = rank_data[i]; for (int j = 0; j < world_size_; j++) { if (j != rank_) { - char *handle = + char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); handle += offsets[j][i]; rd.ptrs[j] = handle; @@ -405,7 +402,7 @@ class CustomAllreduce { * will cause contention on NVLink bus. */ template - void allreduce(cudaStream_t stream, T *input, T *output, int size, + void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) { auto d = packed_t::P::size; if (size % d != 0) @@ -418,7 +415,7 @@ class CustomAllreduce { std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit)); - RankData *ptrs; + RankData* ptrs; cudaStreamCaptureStatus status; CUDACHECK(cudaStreamIsCapturing(stream, &status)); if (status == cudaStreamCaptureStatusActive) { diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index c34a50389c21c..f7868233076cd 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -48,7 +48,7 @@ __global__ void dummy_kernel() { } template -__global__ void set_data(T *data, int size, int myRank) { +__global__ void set_data(T* data, int size, int myRank) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { data[idx] = myRank * 0.11f; @@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) { } template -__global__ void convert_data(const T *data1, const T *data2, double *fdata1, - double *fdata2, int size) { +__global__ void convert_data(const T* data1, const T* data2, double* fdata1, + double* fdata2, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { fdata1[idx] = data1[idx]; @@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1, } } -__global__ void init_rand(curandState_t *state, int size, int nRanks) { +__global__ void init_rand(curandState_t* state, int size, int nRanks) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { for (int i = 0; i < nRanks; i++) { @@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) { } template -__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, +__global__ void gen_data(curandState_t* state, T* data, double* ground_truth, int myRank, int nRanks, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { @@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, } template -void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, +void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, int data_size, bool performance_test) { - T *result; + T* result; cudaStream_t stream; CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); @@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t data_handles[8]; - vllm::Signal *buffer; - T *self_data_copy; + vllm::Signal* buffer; + T* self_data_copy; /** * Allocate IPC buffer * @@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, MPI_COMM_WORLD)); - void *rank_data; + void* rank_data; size_t rank_data_sz = 16 * 1024 * 1024; CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); std::vector offsets(nRanks, 0); vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, offsets, myRank); - auto *self_data = - reinterpret_cast(reinterpret_cast(buffer) + - sizeof(vllm::Signal) + data_size * sizeof(T)); + auto* self_data = + reinterpret_cast(reinterpret_cast(buffer) + + sizeof(vllm::Signal) + data_size * sizeof(T)); // hack buffer registration { std::vector handles; handles.reserve(nRanks); for (int i = 0; i < nRanks; i++) { - char *begin = (char *)&data_handles[i]; - char *end = (char *)&data_handles[i + 1]; + char* begin = (char*)&data_handles[i]; + char* end = (char*)&data_handles[i + 1]; handles.emplace_back(begin, end); } std::vector offsets(nRanks, @@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, fa.register_buffer(handles, offsets, self_data); } - double *ground_truth; + double* ground_truth; CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); - curandState_t *states; + curandState_t* states; CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, @@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, CUDACHECK(cudaStreamDestroy(stream)); } -int main(int argc, char **argv) { +int main(int argc, char** argv) { int nRanks, myRank; MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); @@ -296,7 +296,7 @@ int main(int argc, char **argv) { ncclUniqueId id; ncclComm_t comm; if (myRank == 0) ncclGetUniqueId(&id); - MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, + MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 91abd9e85b4bb..3ecea03242f06 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -6,32 +6,30 @@ #include -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) - -#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) -#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index e56b4d2204005..70a2b3b0a07b1 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -11,26 +11,24 @@ #include #include - using __nv_bfloat16 = __hip_bfloat16; - using __nv_bfloat162 = __hip_bfloat162; +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; #endif namespace vllm { // TODO(woosuk): Further optimize this kernel. -template +template __global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float) input[blockIdx.x * hidden_size + idx]; + const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } variance = blockReduceSum(variance); @@ -40,12 +38,12 @@ __global__ void rms_norm_kernel( __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) input[blockIdx.x * hidden_size + idx]; - out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; } } - /* Converter structs for the conversion from torch types to HIP/CUDA types, and the associated type conversions within HIP/CUDA. These helpers need to be implemented for now because the relevant type conversion @@ -54,51 +52,68 @@ __global__ void rms_norm_kernel( Each struct should have the member static constexpr bool `exists`: If false, the optimized kernel is not used for the corresponding torch type. - If true, the struct should be fully defined as shown in the examples below. + If true, the struct should be fully defined as shown in the examples below. */ -template -struct _typeConvert { static constexpr bool exists = false; }; +template +struct _typeConvert { + static constexpr bool exists = false; +}; #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion -template<> +template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __half; using packed_hip_type = __half2; __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } - __device__ static inline hip_type convert(float x) { return __float2half_rn(x); } - __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } }; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // CUDA_ARCH < 800 does not have BF16 support // TODO: Add in ROCm support once public headers handle bf16 maturely -template<> +template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; - __device__ static inline float convert(hip_type x) { return __bfloat162float(x); } - __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } - __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } - __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } }; -#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) + #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) /* Vector POD struct to generate vectorized and packed FP16/BF16 ops for appropriate specializations of fused_add_rms_norm_kernel. Only functions that are necessary in that kernel are implemented. Alignment to 16 bytes is required to use 128-bit global memory ops. */ -template +template struct alignas(16) _f16Vec { - /* Not theoretically necessary that width is a power of 2 but should - almost always be the case for optimization purposes */ + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ static_assert(width > 0 && (width & (width - 1)) == 0, "Width is not a positive power of 2!"); using Converter = _typeConvert; @@ -108,51 +123,49 @@ struct alignas(16) _f16Vec { __device__ _f16Vec& operator+=(const _f16Vec& other) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i+1]}; - temp += T2{other.data[i], other.data[i+1]}; + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll - for (int i = 0; i < width; ++i) - data[i] += other.data[i]; +#pragma unroll + for (int i = 0; i < width; ++i) data[i] += other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const _f16Vec& other) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i+1]}; - temp *= T2{other.data[i], other.data[i+1]}; + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll - for (int i = 0; i < width; ++i) - data[i] *= other.data[i]; +#pragma unroll + for (int i = 0; i < width; ++i) data[i] *= other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const float scale) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); temp_f.x *= scale; temp_f.y *= scale; T2 temp = Converter::convert(temp_f); data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < width; ++i) { float temp = Converter::convert(data[i]) * scale; data[i] = Converter::convert(temp); @@ -164,13 +177,13 @@ struct alignas(16) _f16Vec { __device__ float sum_squares() const { float result = 0.0f; if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i+1]}); + float2 z = Converter::convert(T2{data[i], data[i + 1]}); result += z.x * z.x + z.y * z.y; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < width; ++i) { float x = Converter::convert(data[i]); result += x * x; @@ -184,15 +197,13 @@ struct alignas(16) _f16Vec { Additional optimizations we can make in this case are packed and vectorized operations, which help with the memory latency bottleneck. */ -template -__global__ std::enable_if_t< - (width > 0) && _typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { // Sanity checks on our vector struct and type-punned pointer arithmetic static_assert(std::is_pod_v<_f16Vec>); static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); @@ -203,9 +214,12 @@ __global__ std::enable_if_t< /* These and the argument pointers are all declared `restrict` as they are not aliased in practice. Argument pointers should not be dereferenced in this kernel as that would be undefined behavior */ - auto* __restrict__ input_v = reinterpret_cast<_f16Vec*>(input); - auto* __restrict__ residual_v = reinterpret_cast<_f16Vec*>(residual); - auto* __restrict__ weight_v = reinterpret_cast*>(weight); + auto* __restrict__ input_v = + reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = + reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; @@ -215,10 +229,11 @@ __global__ std::enable_if_t< residual_v[id] = temp; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else variance = blockReduceSum(variance); + } else + variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -233,52 +248,50 @@ __global__ std::enable_if_t< } } - /* Generic fused_add_rms_norm_kernel The width field is not used here but necessary for other specializations. */ -template -__global__ std::enable_if_t< - (width == 0) || !_typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { scalar_t z = input[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx]; - float x = (float) z; + float x = (float)z; variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else variance = blockReduceSum(variance); + } else + variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) residual[blockIdx.x * hidden_size + idx]; - input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)residual[blockIdx.x * hidden_size + idx]; + input[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; } } -} // namespace vllm +} // namespace vllm -void rms_norm( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +void rms_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -286,40 +299,27 @@ void rms_norm( dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "rms_norm_kernel", - [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + }); } -#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "fused_add_rms_norm_kernel", \ - [&] { \ - vllm::fused_add_rms_norm_kernel \ - <<>>( \ - input.data_ptr(), \ - residual.data_ptr(), \ - weight.data_ptr(), \ - epsilon, \ - num_tokens, \ - hidden_size); \ - }); - -void fused_add_rms_norm( - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& residual, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_kernel \ + <<>>(input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), epsilon, \ + num_tokens, hidden_size); \ + }); + +void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -342,8 +342,8 @@ void fused_add_rms_norm( auto inp_ptr = reinterpret_cast(input.data_ptr()); auto res_ptr = reinterpret_cast(residual.data_ptr()); auto wt_ptr = reinterpret_cast(weight.data_ptr()); - bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ - && wt_ptr % 16 == 0; + bool ptrs_are_aligned = + inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; if (ptrs_are_aligned && hidden_size % 8 == 0) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index 35c328499a22d..4122f7630d7c7 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -3,5 +3,6 @@ #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); + m.def("topk_softmax", &topk_softmax, + "Apply topk softmax to the gating outputs."); } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index a01be3e426d72..93e7844ac1993 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -2,8 +2,6 @@ #include -void topk_softmax( - torch::Tensor& topk_weights, - torch::Tensor& topk_indices, - torch::Tensor& token_expert_indices, - torch::Tensor& gating_output); +void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output); diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index e01b23685ef4e..edc441d121029 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -7,119 +7,128 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) namespace vllm { namespace { -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { - // don't worry about overflow because num_experts is relatively small - return row * total_col + col; -} +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; } +} // namespace template -__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, - int32_t *sorted_token_ids, - int32_t *expert_ids, - int32_t *total_tokens_post_pad, - int32_t num_experts, - int32_t block_size, - size_t numel) { - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - extern __shared__ int32_t shared_mem[]; - - int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) - int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are assigned - * to expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding blocks - * and stores the corresponding expert_id for each block. - */ - for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { - expert_ids[i / block_size] = threadIdx.x; +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + + int32_t* tokens_cnts = + shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + int32_t* cumsum = + shared_mem + (num_experts + 1) * + num_experts; // 1d tensor with shape (num_experts + 1) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; } - - /** - * Each thread processes a token shard, calculating the index of each token after - * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and - * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], - * where * represents a padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] - * stores the indices of the tokens processed by the expert with expert_id within - * the current thread's token shard. - */ - int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } -} + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } } - -void moe_align_block_size( - torch::Tensor topk_ids, - int num_experts, - int block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors - const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); +} // namespace vllm + +void moe_align_block_size(torch::Tensor topk_ids, int num_experts, + int block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t shared_mem = + ((num_experts + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); // set dynamic shared mem auto kernel = vllm::moe_align_block_size_kernel; - AT_CUDA_CHECK( - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); kernel<<<1, num_experts, shared_mem, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - num_experts, - block_size, + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, topk_ids.numel()); - }); + }); } diff --git a/csrc/ops.h b/csrc/ops.h index 3acbc1b3f8363..d6cdfab434f2c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -3,168 +3,119 @@ #include void paged_attention_v1( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale); + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step); void paged_attention_v2( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale); - -void rms_norm( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& weight, - float epsilon); - -void fused_add_rms_norm( - torch::Tensor& input, - torch::Tensor& residual, - torch::Tensor& weight, - float epsilon); - -void rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox); - -void batched_rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox, - int rot_dim, - torch::Tensor& cos_sin_cache_offsets); - -void silu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_tanh_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_new( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_fast( - torch::Tensor& out, - torch::Tensor& input); + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step); -#ifndef USE_ROCM -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy); - -torch::Tensor marlin_gemm( - torch::Tensor& a, - torch::Tensor& b_q_weight, - torch::Tensor& b_scales, - torch::Tensor& workspace, - int64_t size_m, - int64_t size_n, - int64_t size_k); -#endif +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + float epsilon); + +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, float epsilon); + +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox); + +void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox, + int rot_dim, + torch::Tensor& cos_sin_cache_offsets); + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table); - -torch::Tensor gptq_gemm( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama, - int bit); - -void gptq_shuffle( - torch::Tensor q_weight, - torch::Tensor q_perm, - int bit); - -void moe_align_block_size( - torch::Tensor topk_ids, - int num_experts, - int block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); +void gelu_new(torch::Tensor& out, torch::Tensor& input); + +void gelu_fast(torch::Tensor& out, torch::Tensor& input); #ifndef USE_ROCM -using fptr_t = uint64_t; -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, - bool full_nvlink); -bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, - bool full_nvlink); -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out); -void dispose(fptr_t _fa); -int meta_size(); -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets); -std::pair, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector> &offsets); +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias); + +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes); + +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters); + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int split_k_iters, int thx, + int thy); + +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t size_m, int64_t size_n, int64_t size_k); + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k); + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full); + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits); + +int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales); + #endif -void convert_fp8( - torch::Tensor& src_cache, - torch::Tensor& dst_cache, - torch::Tensor& scale); +void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input, + float scale); + +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table); + +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int bit); +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit); + +void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + +void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + +void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, torch::Tensor& scale); + +#ifdef USE_ROCM torch::Tensor fp8_gemm( torch::Tensor& a, torch::Tensor& b, @@ -180,4 +131,32 @@ torch::Tensor fp8_gemm_16( torch::Tensor& scaleA, torch::Tensor& scaleB, int algo_idx -); \ No newline at end of file +); +#endif + +void moe_align_block_size(torch::Tensor topk_ids, int num_experts, + int block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); + +#ifndef USE_ROCM +using fptr_t = uint64_t; +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int rank, + bool full_nvlink); +bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, + bool full_nvlink); +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out); +void dispose(fptr_t _fa); +int meta_size(); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets); +std::pair, std::vector> get_graph_buffer_ipc_meta( + fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets); +#endif diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index d80cb6973fad6..69d6dae1c26bc 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -7,14 +7,10 @@ namespace vllm { -template +template inline __device__ void apply_token_rotary_embedding( - scalar_t* __restrict__ arr, - const scalar_t* __restrict__ cos_ptr, - const scalar_t* __restrict__ sin_ptr, - int rot_offset, - int embed_dim) -{ + scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { int x_index, y_index; scalar_t cos, sin; if (IS_NEOX) { @@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding( arr[y_index] = y * cos + x * sin; } -template +template inline __device__ void apply_rotary_embedding( - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* cache_ptr, - const int head_size, - const int num_heads, - const int num_kv_heads, - const int rot_dim, - const int token_idx, - const int64_t query_stride, - const int64_t key_stride) -{ + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, const int head_size, const int num_heads, + const int num_kv_heads, const int rot_dim, const int token_idx, + const int64_t query_stride, const int64_t key_stride) { const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding(query + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding( + query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } const int nk = num_kv_heads * embed_dim; @@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding(key + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } } -template +template __global__ void rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, - const int num_heads, - const int num_kv_heads, - const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); } -template +template __global__ void batched_rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, - const int num_heads, - const int num_kv_heads, - const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] + // or [num_tokens] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + const scalar_t* cache_ptr = + cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); } -} // namespace vllm +} // namespace vllm void rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] - int head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox) { + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { int64_t num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; @@ -135,36 +141,21 @@ void rotary_embedding( dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), - "rotary_embedding", - [&] { - if (is_neox) { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } else { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, + query_stride, key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size); + } + }); } /* @@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together and process in batched manner. */ void batched_rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] - int head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, - int rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox, int rot_dim, + torch::Tensor& cos_sin_cache_offsets // [num_tokens] ) { int64_t num_tokens = cos_sin_cache_offsets.size(0); int num_heads = query.size(-1) / head_size; @@ -191,36 +183,21 @@ void batched_rotary_embedding( dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), - "rotary_embedding", - [&] { - if (is_neox) { - vllm::batched_rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } else { - vllm::batched_rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } + }); } diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu index c642e94925fe5..86846c274c90f 100644 --- a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu +++ b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu deleted file mode 100644 index e8202dff561d9..0000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu deleted file mode 100644 index 3e7cf31dead0f..0000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu deleted file mode 100644 index 68277fa6b7d56..0000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu index 0607cebfeac40..de39c3121f5d3 100644 --- a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu +++ b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu deleted file mode 100644 index 3b7531b8fbcfc..0000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 2219d960ae62f..4b376261d30d2 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 128) \ f(in_T, out_T, W_T, narrow, 256) \ f(in_T, out_T, W_T, narrow, 512) \ + f(in_T, out_T, W_T, narrow, 640) \ f(in_T, out_T, W_T, narrow, 768) \ f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1152) \ @@ -27,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 3072) \ + f(in_T, out_T, W_T, narrow, 3328) \ f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 4096) \ @@ -35,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 6144) \ + f(in_T, out_T, W_T, narrow, 6400) \ f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 7168) \ @@ -46,11 +49,13 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 15360) \ f(in_T, out_T, W_T, narrow, 16384) \ f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 27392) \ + f(in_T, out_T, W_T, narrow, 27648) \ f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ @@ -58,9 +63,91 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 32768) \ f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 36864) \ + f(in_T, out_T, W_T, narrow, 43264) \ f(in_T, out_T, W_T, narrow, 49152) \ + f(in_T, out_T, W_T, narrow, 64000) \ + f(in_T, out_T, W_T, narrow, 64256) \ + f(in_T, out_T, W_T, narrow, 64512) \ + f(in_T, out_T, W_T, narrow, 102400) \ + f(in_T, out_T, W_T, narrow, 102656) \ + f(in_T, out_T, W_T, narrow, 102912) \ + f(in_T, out_T, W_T, narrow, 128000) \ + f(in_T, out_T, W_T, narrow, 128256) \ + f(in_T, out_T, W_T, narrow, 128512) \ +// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA +// and vllm/tests/lora/test_punica.py + +// Used for defining kernels going from the variety of +// dim in to the narrow dim out + // Using it for the fully sharded column + // parallel LoRA A which splits the rank dim +#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, 128, narrow) \ + f(in_T, out_T, W_T, 256, narrow) \ + f(in_T, out_T, W_T, 512, narrow) \ + f(in_T, out_T, W_T, 640, narrow) \ + f(in_T, out_T, W_T, 768, narrow) \ + f(in_T, out_T, W_T, 1024, narrow) \ + f(in_T, out_T, W_T, 1152, narrow) \ + f(in_T, out_T, W_T, 1280, narrow) \ + f(in_T, out_T, W_T, 1536, narrow) \ + f(in_T, out_T, W_T, 1728, narrow) \ + f(in_T, out_T, W_T, 1792, narrow) \ + f(in_T, out_T, W_T, 2048, narrow) \ + f(in_T, out_T, W_T, 2304, narrow) \ + f(in_T, out_T, W_T, 2560, narrow) \ + f(in_T, out_T, W_T, 2752, narrow) \ + f(in_T, out_T, W_T, 2816, narrow) \ + f(in_T, out_T, W_T, 3072, narrow) \ + f(in_T, out_T, W_T, 3328, narrow) \ + f(in_T, out_T, W_T, 3456, narrow) \ + f(in_T, out_T, W_T, 3584, narrow) \ + f(in_T, out_T, W_T, 4096, narrow) \ + f(in_T, out_T, W_T, 4608, narrow) \ + f(in_T, out_T, W_T, 5120, narrow) \ + f(in_T, out_T, W_T, 5504, narrow) \ + f(in_T, out_T, W_T, 5632, narrow) \ + f(in_T, out_T, W_T, 6144, narrow) \ + f(in_T, out_T, W_T, 6400, narrow) \ + f(in_T, out_T, W_T, 6848, narrow) \ + f(in_T, out_T, W_T, 6912, narrow) \ + f(in_T, out_T, W_T, 7168, narrow) \ + f(in_T, out_T, W_T, 8192, narrow) \ + f(in_T, out_T, W_T, 9216, narrow) \ + f(in_T, out_T, W_T, 10240, narrow) \ + f(in_T, out_T, W_T, 11008, narrow) \ + f(in_T, out_T, W_T, 12288, narrow) \ + f(in_T, out_T, W_T, 13696, narrow) \ + f(in_T, out_T, W_T, 13824, narrow) \ + f(in_T, out_T, W_T, 14336, narrow) \ + f(in_T, out_T, W_T, 15360, narrow) \ + f(in_T, out_T, W_T, 16384, narrow) \ + f(in_T, out_T, W_T, 20480, narrow) \ + f(in_T, out_T, W_T, 22016, narrow) \ + f(in_T, out_T, W_T, 24576, narrow) \ + f(in_T, out_T, W_T, 27392, narrow) \ + f(in_T, out_T, W_T, 27648, narrow) \ + f(in_T, out_T, W_T, 28672, narrow) \ + f(in_T, out_T, W_T, 32000, narrow) \ + f(in_T, out_T, W_T, 32256, narrow) \ + f(in_T, out_T, W_T, 32512, narrow) \ + f(in_T, out_T, W_T, 32768, narrow) \ + f(in_T, out_T, W_T, 33024, narrow) \ + f(in_T, out_T, W_T, 36864, narrow) \ + f(in_T, out_T, W_T, 43264, narrow) \ + f(in_T, out_T, W_T, 49152, narrow) \ + f(in_T, out_T, W_T, 64000, narrow) \ + f(in_T, out_T, W_T, 64256, narrow) \ + f(in_T, out_T, W_T, 64512, narrow) \ + f(in_T, out_T, W_T, 102400, narrow) \ + f(in_T, out_T, W_T, 102656, narrow) \ + f(in_T, out_T, W_T, 102912, narrow) \ + f(in_T, out_T, W_T, 128000, narrow) \ + f(in_T, out_T, W_T, 128256, narrow) \ + f(in_T, out_T, W_T, 128512, narrow) \ // Keep above in sync with vllm/lora/layers::SamplerWithLoRA + // Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ @@ -68,4 +155,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) + +#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ + FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \ + FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \ + FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \ + f(in_T, out_T, W_T, 8, 64) \ + f(in_T, out_T, W_T, 16, 64) \ + f(in_T, out_T, W_T, 32, 64) \ + f(in_T, out_T, W_T, 64, 64) + // clang-format on diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu deleted file mode 100644 index b3b74aa3ec904..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu deleted file mode 100644 index 3cc87f5df76a1..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu deleted file mode 100644 index 9eda98bd8ddcf..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu index f1db6df5f7338..d225a1eaa82b0 100644 --- a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu +++ b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu deleted file mode 100644 index 060f9ebb8c2b1..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu index c01ddd009d74e..b37d288a75561 100644 --- a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu +++ b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu index f45183ffd3486..a1ab2deecbabf 100644 --- a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu +++ b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu deleted file mode 100644 index b37e44570bf40..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu deleted file mode 100644 index 06718cbb0a3e9..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu index 4097743488087..0b35bf5699898 100644 --- a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu +++ b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu deleted file mode 100644 index 41fb0e45ef4e6..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu deleted file mode 100644 index 50b7ead9fcefd..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index 995de26e8bada..8a3b8403b4a6f 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -1,8 +1,14 @@ #pragma once #include +#ifndef USE_ROCM #include +#else +#include +#endif +#ifndef USE_ROCM #include +#endif #include #include #include @@ -11,6 +17,24 @@ namespace cg = cooperative_groups; +#ifdef USE_ROCM +template +__host__ __device__ +inline void* memcpy_blocking(void *dst, const void *src) { + // Does not handle the case of long datatypes + char *d = reinterpret_cast(dst); + const char *s = reinterpret_cast(src); + size_t i = 0; +#pragma unroll + for (i = 0; i < len; ++i) { + d[i] = s[i]; + } + return dst; +} +#endif + +#ifndef USE_ROCM + // nthrs = (32, 4) template +__global__ void +bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + if (idx < 0) { + return; + } + + size_t j = blockIdx.x; + constexpr size_t tile_size = tx * ty * vec_size; + constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; + __shared__ float y_warpwise[ty]; + + float y = 0; + vec_t x_vec; + vec_t w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { + x_vec.load(X + (batch_idx * feat_in) + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + } + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += VLLM_SHFL_DOWN_SYNC(sum, offset); + } + + __syncthreads(); + + if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { + y += sum; + } + } + + if (threadIdx.x == 0) { + y_warpwise[threadIdx.y] = y; + } + __syncthreads(); + + float y_write = 0.f; +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y_write += y_warpwise[i]; + } + + // write Y; + if (threadIdx.x == 0 && threadIdx.y == 0) { + size_t y_idx = batch_idx * full_y_size + y_offset + j; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(y_write)); + } +} + +#endif + // nthrs = (2, 16, 4) template @@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, float sum = 0.f; #pragma unroll for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif } cg::thread_block_tile g = cg::tiled_partition(block); @@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, sum = g.shfl(sum, 0); if (threadIdx.x == 0) { +#ifndef USE_ROCM Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + threadIdx.z * ty + threadIdx.y] += static_cast(sum); +#else + size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); +#endif } } @@ -199,7 +308,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, constexpr int tz = 4; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if constexpr (feat_in < feat_out) { + if constexpr (feat_in <= feat_out) { static_assert(feat_in % vec_size == 0); constexpr int tx = feat_in / vec_size; @@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, scale); } } else { +#ifndef USE_ROCM static_assert(feat_in % (vec_size * 32) == 0 || feat_in % (vec_size * 16) == 0 || feat_in % (vec_size * 8) == 0); @@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, full_y_size, num_layers, layer_idx, scale); } +#else + constexpr size_t rocm_warp_size = warpSize; + +#define CHECK_INPUT_TILEABLE_BY(vec_size_) \ + feat_in % (rocm_warp_size * vec_size_) == 0 + +#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \ + if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \ + constexpr size_t vec_size_shrink = vec_size_; \ + constexpr int tx = tx_; \ + constexpr int ty = ty_; \ + dim3 nblks(feat_out, batch_size); \ + dim3 nthrs(tx, ty); \ + bgmv_shrink_kernel \ + <<>>(Y, X, W, indicies, y_offset, \ + full_y_size, num_layers, layer_idx, \ + scale); \ + } + + static_assert(CHECK_INPUT_TILEABLE_BY(32) || + CHECK_INPUT_TILEABLE_BY(16) || + CHECK_INPUT_TILEABLE_BY( 8) || + CHECK_INPUT_TILEABLE_BY( 4) || + CHECK_INPUT_TILEABLE_BY( 2) || + CHECK_INPUT_TILEABLE_BY( 1)); + + LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) + +#undef CHECK_INPUT_TILEABLE_BY +#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM +#endif } } @@ -289,6 +443,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ int64_t num_layers, int64_t layer_idx, float scale); +#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \ + INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) + #define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ INST_BGMV(narrow, wide, in_T, out_T, W_T) \ INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py index c347d4f2ab9f4..972df5a7208c2 100644 --- a/csrc/punica/bgmv/generator.py +++ b/csrc/punica/bgmv/generator.py @@ -10,6 +10,7 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype}) """.lstrip() # noqa: E501 for input_dtype in DTYPES: @@ -18,6 +19,26 @@ if weight_dtype == "fp32": # FP32 weights are not supported. continue + if output_dtype == "fp32": + # LoRA A matrix. + if input_dtype != weight_dtype: + # NOTE(woosuk): While Punica supports the case where the + # input and weight dtypes are different, we only generate + # the kernels the same dtypes to reduce the binary size. + continue + elif input_dtype == "fp32": + # LoRA B matrix. + if output_dtype != weight_dtype: + # NOTE(woosuk): While Punica supports the case where the + # output and weight dtypes are different, we only generate + # the kernels the same dtypes to reduce the binary size. + continue + elif not (input_dtype == output_dtype == weight_dtype): + # NOTE(woosuk): While Punica supports mixed data types for + # input, output, and weight, we only generate the kernels with + # the same data types to reduce the binary size. + continue + kernel_definition = TEMPLATE.format( input_dtype=DTYPE_MAP[input_dtype], output_dtype=DTYPE_MAP[output_dtype], diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh index cf00d869cf635..2738892e6dc4a 100644 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -1,8 +1,6 @@ #ifndef VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_ -#include -#include #ifdef FLASHINFER_USE_FP8 #include #endif @@ -10,6 +8,9 @@ #include +#include "../type_convert.h" +#include "../../cuda_compat.h" + #define FLASHINFER_INLINE \ inline __attribute__((always_inline)) __device__ __host__ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cu similarity index 95% rename from csrc/punica/punica_ops.cc rename to csrc/punica/punica_ops.cu index 28739be14b862..61de3b37937cc 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cu @@ -1,12 +1,11 @@ -#include -#include #include #include #include +#include "type_convert.h" +#include "../cuda_compat.h" #include "bgmv/bgmv_config.h" -namespace { //====== utils ====== @@ -20,8 +19,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, } } -inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { - return (uint32_t(a) << 16) | uint32_t(b); +inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) { + return (uint64_t(a) << 32) | uint64_t(b); } #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") @@ -46,13 +45,30 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { template inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, const int64_t *lora_indices, - uint16_t in_features, uint16_t out_features, + uint32_t in_features, uint32_t out_features, int64_t y_offset, int64_t full_y_size, int64_t batch_size, int64_t num_layers, int64_t layer_idx, float scale) { - switch (pack_u16(in_features, out_features)) { + // NOTE(woosuk): While Punica supports various combinations of input/output + // data types, we limit the supported data types to reduce the binary size. + constexpr bool is_input_float = std::is_same::value; + constexpr bool is_output_float = std::is_same::value; + if (is_input_float) { + if (!std::is_same::value) { + return false; + } + } else if (is_output_float) { + if (!std::is_same::value) { + return false; + } + } else if (!(std::is_same::value && + std::is_same::value)) { + return false; + } + + switch (pack_u32(in_features, out_features)) { #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ - case pack_u16(feat_in, feat_out): \ + case pack_u32(feat_in, feat_out): \ bgmv_kernel(Y, X, W, lora_indices, y_offset, \ full_y_size, batch_size, num_layers, \ layer_idx, scale); \ @@ -62,12 +78,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) FOR_BGMV_WIDE_NARROW(CASE, _, _, _) + FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _) #undef CASE #undef CASE_ONESIDE default: return false; } - return true; } @@ -93,7 +109,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, CHECK_EQ(y.size(0), x.size(0)); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); bool ok = false; - if (h_in < 65536 && h_out < 65536) { + if (h_in <= 128512 && h_out <= 128512) { // TODO: See if we can get rid of this massive nested switch switch (x.scalar_type()) { case at::ScalarType::Half: @@ -325,7 +341,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, CHECK_EQ(y.size(0), x.size(0)); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); bool ok = false; - if (h_in < 65536 && h_out < 65536) { + if (h_in <= 128512 && h_out <= 128512) { // TODO: See if we can get rid of this massive nested switch switch (x.scalar_type()) { case at::ScalarType::Half: @@ -551,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); } - -} // namespace - -//====== pybind ====== - -#define DEFINE_pybind(name) m.def(#name, &name, #name); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); - m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, - "dispatch_bgmv_low_level"); -} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h new file mode 100644 index 0000000000000..937e2d1d25d4a --- /dev/null +++ b/csrc/punica/punica_ops.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale); + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp new file mode 100644 index 0000000000000..9490ad59cdd5f --- /dev/null +++ b/csrc/punica/punica_pybind.cpp @@ -0,0 +1,13 @@ +#include + +#include "punica_ops.h" + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h new file mode 100644 index 0000000000000..dff7ce49283d7 --- /dev/null +++ b/csrc/punica/type_convert.h @@ -0,0 +1,82 @@ +#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ +#define CSRC__PUNICA__TYPE_CONVERT_H__ + +#ifndef USE_ROCM + +#include +#include + +#else + +#include +#include + +#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ + +typedef __half nv_half; +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { + return __hip_bfloat162{val, val}; +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { + return __hip_bfloat162{vall, valr}; +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T_dst convert_type(T_src val) { + return static_cast(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__half, float>(__half val) { + return __half2float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half convert_type(float val) { + return __float2half(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 convert_type(float val) { + return __float2bfloat16(val); +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T vllm_add(T a, T b) { + return a + b; +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half vllm_add<__half>(__half a, __half b) { + return __hadd(a, b); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { + return __hadd(a, b); +} + +#undef __TYPE_CONVERT__HOST_DEVICE__ + +#endif // USE_ROCM + +#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index d533b6fcca498..a4693ccc2ae75 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -8,109 +8,95 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); + ops.def("paged_attention_v1", &paged_attention_v1, + "Compute the attention between an input query and the cached " + "keys/values using PagedAttention."); + ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); // Activation ops - ops.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - ops.def( - "gelu_and_mul", - &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def( - "gelu_tanh_and_mul", - &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - ops.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); + ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + ops.def("gelu_and_mul", &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); + ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); // Layernorm - ops.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + ops.def("rms_norm", &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); - ops.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); + ops.def("fused_add_rms_norm", &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); // Rotary embedding - ops.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + ops.def("rotary_embedding", &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); - ops.def( - "batched_rotary_embedding", - &batched_rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"); - -// FP8 - ops.def( - "convert_fp8", - &convert_fp8, - "Convert the key and value tensors to or from fp8 data type"); - ops.def("fp8_gemm", &fp8_gemm, "fp8 GEMM"); - - ops.def("fp8_gemm_16", &fp8_gemm_16, "fp8 GEMM"); + ops.def("batched_rotary_embedding", &batched_rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key " + "(supports multiple loras)"); // Quantization ops #ifndef USE_ROCM + ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); + ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); + ops.def("marlin_gemm", &marlin_gemm, + "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, + "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, + "gptq_marlin Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_repack", &gptq_marlin_repack, + "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); + ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, + "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or " + "per-row/column quantization."); #endif - + ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def( - "moe_align_block_size", - &moe_align_block_size, - "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); + ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, + "Compute FP8 quantized tensor for given scaling factor"); + ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, + "Compute FP8 quantized tensor and scaling factor"); + ops.def("moe_align_block_size", &moe_align_block_size, + "Aligning the number of tokens to be processed by each expert such " + "that it is divisible by the block size."); + ops.def("convert_fp8", &convert_fp8, + "Convert the key and value cache to fp8 data type"); + +#ifdef USE_ROCM + ops.def("fp8_gemm", &fp8_gemm, "fp8 GEMM with fp8 output"); + ops.def("fp8_gemm_16", &fp8_gemm_16, "fp8 GEMM with fp16 output"); +#endif + + ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, + "Compute int8 quantized tensor for given scaling factor"); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); + cache_ops.def("swap_blocks", &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def("copy_blocks", ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def("reshape_and_cache", &reshape_and_cache, + "Reshape the key and value tensors and cache them"); + cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash, + "Reshape the key and value tensors and cache them"); // Cuda utils - pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); - cuda_utils.def( - "get_device_attribute", - &get_device_attribute, - "Gets the specified device attribute."); + pybind11::module cuda_utils = + m.def_submodule("cuda_utils", "vLLM cuda utils"); + cuda_utils.def("get_device_attribute", &get_device_attribute, + "Gets the specified device attribute."); - cuda_utils.def( - "get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute, - "Gets the maximum shared memory per block device attribute."); + cuda_utils.def("get_max_shared_memory_per_block_device_attribute", + &get_max_shared_memory_per_block_device_attribute, + "Gets the maximum shared memory per block device attribute."); #ifndef USE_ROCM // Custom all-reduce kernels @@ -127,5 +113,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { custom_ar.def("register_graph_buffers", ®ister_graph_buffers, "register_graph_buffers"); #endif - } diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu new file mode 100644 index 0000000000000..255844eec56d4 --- /dev/null +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -0,0 +1,598 @@ +/* + * Modified by Neural Magic + * Adapted from https://github.com/Vahe1994/AQLM + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace vllm { +namespace aqlm { + +__global__ void Code1x16MatVec( + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, + const int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; + } + } + + int b_gl_rd = 0; + int c_gl_wr = a_gl_rd; + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + + __shared__ int4 sh_b[32 * 9]; + float res = 0; + + int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); + while (iters--) { + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * 8; + + int b_sh_rd = 9 * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t dec[4]; + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. + asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); + half2* a = reinterpret_cast(&dec); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; +#pragma unroll + for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); + b_sh_rd++; + } + a_gl_rd += 32; + } + } + + if (pred) { +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } +} + +__global__ void Code2x8MatVec( + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. + +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; + } + } + + int b_gl_rd = 0; + int c_gl_wr = a_gl_rd; + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + int lane = threadIdx.x % 8; + + extern __shared__ int4 sh[]; + int4* sh_b = sh; + int4* sh_code = sh_b + 32 * 9; + int4* sh_code0 = sh_code; + int4* sh_code1 = sh_code + 256 * 8; + + for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { + int4 dec = codebook[i]; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; + } + __syncthreads(); + + float res = 0; + + int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); + while (iters--) { + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * 8; + + int b_sh_rd = 9 * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); +#pragma unroll + for (int i = 0; i < 8; i++) { + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; +#pragma unroll + for (int j = 0; j < 4; j++) + res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); + b_sh_rd++; + } + a_gl_rd += 32; + } + } + + if (pred) { +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } +} + +__global__ void Code1x16Dequant( + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long, sums to m. + const int codebook_stride // as int4 +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; + } + } + + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + + int c_gl_stride = prob_k / 8; + int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; + + int iters = (prob_k / 8 - 1) / (8 * 32) + 1; + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); +#pragma unroll + for (int i = 0; i < 8; i++) { + int4 chunk; + auto dec = reinterpret_cast(&chunk); + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. + asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); + + C[a_gl_rd * 8 + i] = chunk; + } + } + a_gl_rd += 32; + } +} + +__global__ void Code2x8Dequant( + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; + } + } + + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + int lane = threadIdx.x % 8; + + int c_gl_stride = prob_k / 8; + int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; + + extern __shared__ int4 sh[]; + int4* sh_code = sh; + int4* sh_code0 = sh_code; + int4* sh_code1 = sh_code + 256 * 8; + + for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { + int4 dec = codebook[i]; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; + } + __syncthreads(); + + float res = 0; + + int iters = (prob_k / 8 - 1) / (8 * 32) + 1; + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); +#pragma unroll + for (int i = 0; i < 8; i++) { + int4 chunk; + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); + C[a_gl_rd * 8 + i] = chunk; + } + } + a_gl_rd += 32; + } +} + +inline int ceildiv(int a, int b) { return (a + b - 1) / b; } + +const int THREAD_M = 16; + +void code1x16_matvec_cuda(const void* __restrict__ A, + const void* __restrict__ B, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16MatVec<<>>( + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); +} + +void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + int shared = 16 * (2 * 256 * 8 + 32 * 9); + cudaFuncSetAttribute(Code2x8MatVec, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code2x8MatVec<<>>( + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); +} + +void code1x16_dequant_cuda( + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16Dequant<<>>( + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long. + codebook_stride // as int4. + ); +} + +// Dequantizes the code and codebook into weights. +void code2x8_dequant_cuda( + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + int shared = 16 * (2 * 256 * 8 + 32 * 9); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + cudaFuncSetAttribute(Code2x8Dequant, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); + Code2x8Dequant<<>>( + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, codebook_stride); +} + +int codebook_stride(const torch::Tensor& codebooks) { + return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); +} + +void code1x16_matvec( + const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes // cumulative sizes of A spanning each + // codebook, at most 3 long. +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); + + code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + codebook_stride(codebook)); +} + +torch::Tensor code1x16_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); + code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); + } + flat_output *= scales.flatten().unsqueeze(0); + + if (bias.has_value()) { + flat_output += bias->unsqueeze(0); + } + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(-1); + auto output = flat_output.reshape(output_sizes); + return output; +} + +void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, + torch::Tensor& C, const torch::Tensor& codebook, + const int4 codebook_a_sizes) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); + code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + 2 * codebook_stride(codebook)); +} + +torch::Tensor code2x8_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); + code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); + } + flat_output *= scales.flatten().unsqueeze(0); + if (bias.has_value()) { + flat_output += bias->unsqueeze(0); + } + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(-1); + auto output = flat_output.reshape(output_sizes); + return output; +} + +// Accumulate the partition sizes. +int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { + int4 cumulative_sizes; + auto cumulative_size = &cumulative_sizes.x; + int i = 0; + int last = 0; + assert(codebook_partition_sizes.size(0) <= 4); + for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) { + *cumulative_size = codebook_partition_sizes[i].item() + last; + last = *cumulative_size; + } + // fill in the rest with unreachable. + for (; i < 4; ++i, ++cumulative_size) { + *cumulative_size = last * 10; + } + return cumulative_sizes; +} + +} // namespace aqlm +} // namespace vllm + +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); + + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); + int const entries = codebooks.size(1); + + if (nbooks == 1 && entries == (1 << 16)) { + return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); + } + if (nbooks == 2 && entries == (1 << 8)) { + return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); + } + + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") + return {}; +} + +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); + + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); + int const entries = codebooks.size(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(codes)); + int rows = codes.size(1); + int cols = codes.size(0); + + auto in_features = codes.size(1) * 8; + auto out_features = codes.size(0); + + assert(out_features = codebook_partition_sizes.sum().item()); + + auto weights = torch::empty({out_features, in_features}, + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device())); + + if (nbooks == 1 && entries == (1 << 16)) { + vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation.) weights *= + // scales.index({"...", 0, 0}); + + return weights; + } + + if (nbooks == 2 && entries == (1 << 8)) { + vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation) weights *= + // scales.index({"...", 0, 0}); + + return weights; + } + + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") + return {}; +} diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh index d1d926de18d78..813ec6716cf54 100644 --- a/csrc/quantization/awq/dequantize.cuh +++ b/csrc/quantization/awq/dequantize.cuh @@ -1,11 +1,11 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq -Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +Modified from NVIDIA FasterTransformer: +https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ @@ -14,74 +14,88 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor namespace vllm { namespace awq { -__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) -{ +__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 assert(false); #else - uint4 result; + uint4 result; - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. + // Note that the entire sequence only requires 1 shift instruction. This is + // thanks to the register packing format and the fact that we force our + // integers to be unsigned, and account for this in the fp16 subtractions. In + // addition, I exploit the fact that sub and fma have the same throughput in + // order to convert elt_23 and elt_67 to fp16 without having to shift them to + // the bottom bits before hand. - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. - const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW + // dependency if we issue immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. + // I use inline PTX below because I am not sure if the compiler will emit + // float2half instructions if I use the half2 ctor. In this case, I chose + // performance reliability over code readability. - // This is the half2 {1032, 1032} represented as an integer. - // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - // static constexpr uint32_t NEG_72 = 0xd480d480; - // Haotian: Let's use {-64, -64}. - static constexpr uint32_t NEG_64 = 0xd400d400; + // This is the half2 {1032, 1032} represented as an integer. + // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + // static constexpr uint32_t NEG_72 = 0xd480d480; + // Haotian: Let's use {-64, -64}. + static constexpr uint32_t NEG_64 = 0xd400d400; - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - return result; + return result; #endif } -} // namespace awq -} // namespace vllm +} // namespace awq +} // namespace vllm diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 5aefb0bd16aef..bb8e5bbb23d7f 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -1,14 +1,12 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ - #include #include @@ -20,26 +18,20 @@ namespace vllm { namespace awq { // Pack two half values. -static inline __device__ __host__ unsigned -__pack_half2(const half x, const half y) { - unsigned v0 = *((unsigned short *)&x); - unsigned v1 = *((unsigned short *)&y); +static inline __device__ __host__ unsigned __pack_half2(const half x, + const half y) { + unsigned v0 = *((unsigned short*)&x); + unsigned v1 = *((unsigned short*)&y); return (v1 << 16) | v0; } -template -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( - int G, - int split_k_iters, - half* __restrict__ A, - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - int M, - int IC, - int OC, - half* __restrict__ C) -{ +template +__global__ void __launch_bounds__(64) + gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, + half* __restrict__ A, int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, int M, int IC, + int OC, half* __restrict__ C) { // Only support matrix n = 64 or 128 assert(N == 64 || N == 128); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 @@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( static constexpr int row_stride = 2 * 32 * 8 / N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + bool ld_A_flag = + (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * (256 / N) - + (((int)threadIdx.x) / (N / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + (((int)threadIdx.x) % (N / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) - + (((int)threadIdx.x) / (N / 8)) * (N + 8) - + (((int)threadIdx.x) % (N / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + ((int)threadIdx.x) % (N / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * N - + (((int)threadIdx.x) % (N / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * N - + ((int)threadIdx.y) * (N / 2) - + (((int)threadIdx.x) % 4) * 2; + half* A_ptr = + A + + (((int)blockIdx_y) / j_factors1 * 16 + + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * + IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; + // Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8)) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = + C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; // preload s.f. and zeros int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; @@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; __syncthreads(); // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { + if (ld_A_flag) { *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { + } else { *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); } // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + uint4 B_loaded_scale = + *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && + threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, + B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, + B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); } */ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { - // B: 32 x 136 (128+8) float16 // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus + // zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * + // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) + // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * + // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * + // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * + // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = + *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / + // 8)) * 8); - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x + // % (cta_N / 8)) * 8); // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = + // q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == + 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", + B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); } */ // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = + B_loaded_fp16; } __syncthreads(); @@ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + + (((((int)threadIdx.x) & 15) * 40) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(A_shared_warp + 0))[0]), + "=r"(((unsigned*)(A_shared_warp + 0))[1]), + "=r"(((unsigned*)(A_shared_warp + 0))[2]), + "=r"(((unsigned*)(A_shared_warp + 0))[3]) + : "r"(addr)); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) - ); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + + (((int)threadIdx.y) * (N / 2))) + + (ax1_0 * 16))])) + + (((((int)threadIdx.x) & 15) * (N + 8)) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr)); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#else + #else { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#endif + #endif } } } -// TODO: Shang: Hoist loop invariance. + // TODO: Shang: Hoist loop invariance. for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); } } } #endif } -__global__ void __launch_bounds__(64) dequantize_weights( - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - half* __restrict__ C, - int G -) -{ +__global__ void __launch_bounds__(64) + dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, + int* __restrict__ zeros, half* __restrict__ C, int G) { int j_factors1 = 4; int row_stride2 = 4; int split_k_iters = 1; @@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights( uint32_t B_loaded = *(uint32_t*)B_ptr2; uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); *(uint4*)B_shared_ptr2 = B_loaded_fp16; @@ -326,58 +430,57 @@ __global__ void __launch_bounds__(64) dequantize_weights( } } -} // namespace awq -} // namespace vllm - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy) -{ - int in_c = _kernel.size(0); - int qout_c = _kernel.size(1); - int out_c = qout_c * 8; - int G = in_c / _scaling_factors.size(0); - - int x_thread = thx; - int y_thread = thy; - - int x_blocks = 1; - int y_blocks = 1; - if (thx==0) { - x_thread = qout_c; - } - if (thy==0) { - y_thread = in_c; - } - if (thx==0 && thy==0) { - x_thread = 8; - y_thread = 8; - x_blocks = (int)(qout_c / 8); - y_blocks = (int)(in_c / 8); - } +} // namespace awq +} // namespace vllm + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int split_k_iters, int thx, + int thy) { + int in_c = _kernel.size(0); + int qout_c = _kernel.size(1); + int out_c = qout_c * 8; + int G = in_c / _scaling_factors.size(0); + + int x_thread = thx; + int y_thread = thy; + + int x_blocks = 1; + int y_blocks = 1; + if (thx == 0) { + x_thread = qout_c; + } + if (thy == 0) { + y_thread = in_c; + } + if (thx == 0 && thy == 0) { + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + } - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + auto options = torch::TensorOptions() + .dtype(_scaling_factors.dtype()) + .device(_scaling_factors.device()); + at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); - dim3 num_blocks(x_blocks, y_blocks); - dim3 threads_per_block(x_thread, y_thread); + dim3 num_blocks(x_blocks, y_blocks); + dim3 threads_per_block(x_thread, y_thread); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::awq::dequantize_weights<<>>( - kernel, scaling_factors, zeros, de_kernel, G); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::awq::dequantize_weights<<>>( + kernel, scaling_factors, zeros, de_kernel, G); - return _de_kernel; + return _de_kernel; } // in_feats: M, IC [float16] @@ -386,61 +489,61 @@ torch::Tensor awq_dequantize( // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 128 == 0) - { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - else if (num_out_channels % 64 == 0) - { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - return _out_feats.sum(0); +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters) { + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); + at::Tensor _out_feats = + torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } else if (num_out_channels % 64 == 0) { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * + split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } + return _out_feats.sum(0); } diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu new file mode 100644 index 0000000000000..4902e4c23434c --- /dev/null +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -0,0 +1,59 @@ +#include +#include +#include + +#include "../../dispatch_utils.h" + +static inline __device__ int8_t float_to_int8_rn(float x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +namespace vllm { + +template +__global__ void static_scaled_int8_quant_kernel( + const scalar_t* __restrict__ input, int8_t* __restrict__ out, + scale_type scale, const int hidden_size) { + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); + } +} +} // namespace vllm + +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + float scale) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { + vllm::static_scaled_int8_quant_kernel + <<>>(input.data_ptr(), + out.data_ptr(), scale, + hidden_size); + }); +} diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp new file mode 100644 index 0000000000000..999b7b251ab33 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/common.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "cutlass/cutlass.h" + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + cutlassGetStatusString(status)) \ + } diff --git a/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp b/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp new file mode 100644 index 0000000000000..ddbee15e54ab6 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass It's beem modified to support either +// row/column or scalar broadcasting, like is already supported in CUTLASS 3.x. +// Important because this saves us a factor 4x on the number of kernels +// compiled. +// +#pragma once + +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cute/tensor.hpp" + +// clang-format on + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + struct Arguments { + Element const* ptr_row = nullptr; + Element null_default = Element(0); + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load(dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = params_ptr->null_default; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if(get<1>(coord_v(i)) < n) + { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + struct Arguments { + Element const* ptr_col = nullptr; + Element null_default = Element(0); + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->ptr_col) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if(pred(i)){ + dst_v(i) = params_ptr->null_default; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu new file mode 100644 index 0000000000000..3a6b8a226e18c --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -0,0 +1,300 @@ +#include +#include + +#include + +// clang-format will break include orders +// clang-format off +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/util/device_memory.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm_coord.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" + +#include "cutlass_visitor_2x_broadcast_epilogue.hpp" +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This defines a quantized GEMM operation with dequantized output, similar to + torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for + NVIDIA GPUs with SM versions prior to sm90 (Hopper). + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ + +namespace { + +template +struct cutlass_2x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using Operator = + typename std::conditional, + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type; + + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + TileShape, WarpShape, float, 4, 1 /* epilogue stages */ + >; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, float, Stride, Int<0>, Int<0>>>; + + using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, float, Stride, Int<1>, Int<0>>>; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::threadblock::Sm80EVT; + + using D = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, + Stride, Int<0>>>; + + using EVTD = cutlass::epilogue::threadblock::Sm80EVT; + + // clang-format off + using RowMajor = typename cutlass::layout::RowMajor; + using ColumnMajor = typename cutlass::layout::ColumnMajor; + using KernelType = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16, + ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16, + float, cutlass::layout::RowMajor, 4, + ElementAcc, float, cutlass::arch::OpClassTensorOp, + Arch, + TileShape, WarpShape, InstructionShape, + EVTD, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + MainLoopStages, Operator, + 1 /* epilogue stages */ + >::GemmKernel; + // clang-format on + + using Op = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + cutlass::gemm::GemmCoord problem_size{m, n, k}; + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideC = Stride, Int<0>>; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); + + auto a_scales_ptr = a_scales.data_ptr(); + auto b_scales_ptr = b_scales.data_ptr(); + + // If A and B are quantized per-tensor, then these scale tensors are scalars, + // and they are passed in via the second argument. + using ScaleAArgs = typename Gemm::ScaleA::Arguments; + ScaleAArgs a_args = a_scales.numel() == 1 + ? ScaleAArgs{nullptr, a_scales.item(), {}} + : ScaleAArgs{a_scales.data_ptr(), {}, {}}; + + using ScaleBArgs = typename Gemm::ScaleB::Arguments; + ScaleBArgs b_args = b_scales.numel() == 1 + ? ScaleBArgs{nullptr, b_scales.item(), {}} + : ScaleBArgs{b_scales.data_ptr(), {}, {}}; + + typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args}; + + typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args, + evt0_compute_args}; + typename Gemm::D::Arguments d_args{c_ptr, c_stride}; + + typename Gemm::EVTD::Arguments epilogue_args{ + evt1_compute_args, + d_args, + }; + + typename Gemm::Op::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode + problem_size, // problem size + 1, // batch count + epilogue_args, + a_ptr, + b_ptr, + nullptr, + nullptr, + 0, + 0, + 0, + 0, + lda, + ldb, + ldc, + ldc}; + + // Launch the CUTLASS GEMM kernel. + typename Gemm::Op gemm_op; + size_t workspace_size = gemm_op.get_workspace_size(args); + cutlass::device_memory::allocation workspace(workspace_size); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + CUTLASS_CHECK(gemm_op.can_implement(args)); + cutlass::Status status = gemm_op(args, workspace.get(), stream); + CUTLASS_CHECK(status); +} + +} // namespace + +void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>(out, a, b, a_scales, + b_scales); + } +} + +void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>(out, a, b, a_scales, + b_scales); + } +} + +void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (a.dtype() == torch::kInt8) { + TORCH_CHECK(b.dtype() == torch::kInt8); + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } + } else { + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, + b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, + b_scales); + } + } +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu new file mode 100644 index 0000000000000..531414bc45165 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -0,0 +1,249 @@ +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + +#include + +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This defines a quantized GEMM operation with dequantized output, similar to + torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for + NVIDIA GPUs with sm90a (Hopper) or later. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ + +namespace { + +template +struct cutlass_3x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, + Stride, Int<0>, Int<0>>>; + + using ScaleBDescriptor = + cutlass::epilogue::collective::detail::RowBroadcastDescriptor< + EpilogueDescriptor, float>; + + using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast< + ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, + typename ScaleBDescriptor::Element, Stride, Int<1>, Int<0>>>; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + EpilogueSchedule, EVTCompute1>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, 16, + ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, Int<0>{}}; + StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + using ScaleA_Args = typename Gemm::ScaleA::Arguments; + using ScaleB_Args = typename Gemm::ScaleB::Arguments; + ScaleA_Args a_args = a_scales.numel() == 1 + ? ScaleA_Args{nullptr, a_scales.item(), {}} + : ScaleA_Args{a_scales.data_ptr(), {}, {}}; + + ScaleB_Args b_args = b_scales.numel() == 1 + ? ScaleB_Args{nullptr, b_scales.item(), {}} + : ScaleB_Args{b_scales.data_ptr(), {}, {}}; + + args.epilogue.thread = {a_args, {b_args}}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + TORCH_CHECK(workspace_size == 0); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + cutlass::Status status = gemm_op.run(args, stream); + CUTLASS_CHECK(status); +} +} // namespace + +void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (a.dtype() == torch::kInt8) { + TORCH_CHECK(b.dtype() == torch::kInt8); + + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } + } else { + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } + } +} + +#endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu new file mode 100644 index 0000000000000..eb532f2ac7a9b --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -0,0 +1,75 @@ +#include + +#include +#include + +void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); + +void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); + +void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); +#endif + +void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + int32_t major_capability; + int32_t minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + + // Checks for conformality + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); + + if (version_num >= 90) { + // Hopper + + // Guard against compilation issues for sm90 kernels +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales); +#else + cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); +#endif + } else if (version_num == 89) { + // Ada Lovelace + cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales); + } else if (version_num >= 80) { + // Ampere + cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); + } else { + // Turing + TORCH_CHECK(version_num >= 75); + cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales); + } +} diff --git a/csrc/quantization/fp8/gemm_kernel.cu b/csrc/quantization/fp8/amd/gemm_kernel.cu similarity index 99% rename from csrc/quantization/fp8/gemm_kernel.cu rename to csrc/quantization/fp8/amd/gemm_kernel.cu index 558228db2c084..5464e9381e343 100644 --- a/csrc/quantization/fp8/gemm_kernel.cu +++ b/csrc/quantization/fp8/amd/gemm_kernel.cu @@ -266,4 +266,4 @@ torch::Tensor fp8_gemm_16( // hipFree(d_scaleB); return result; -} +} \ No newline at end of file diff --git a/csrc/quantization/fp8/amd/hip_float8.h b/csrc/quantization/fp8/amd/hip_float8.h new file mode 100644 index 0000000000000..f9c80fcdec576 --- /dev/null +++ b/csrc/quantization/fp8/amd/hip_float8.h @@ -0,0 +1,137 @@ +#pragma once + +#ifdef __HIPCC__ + #include +#else + #include + #include + #include + #include +#endif + +#include "hip_float8_impl.h" + +struct alignas(1) hip_fp8 { + struct from_bits_t {}; + HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + uint8_t data; + + hip_fp8() = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; + explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) + : data(v) {} + +#ifdef __HIP__MI300__ + // NOTE: ON-DEVICE... always optimal bias + explicit HIP_FP8_DEVICE hip_fp8(float v) + : data(hip_fp8_impl::to_fp8_from_fp32(v)) {} + + explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) + : hip_fp8(static_cast(v)) {} + + // Host only implementation using s/w simulation + explicit HIP_FP8_HOST +#else // __HIP__MI300__ + // both Host and DEVICE for non-MI300 using s/w simulation + explicit HIP_FP8_HOST_DEVICE +#endif // __HIP__MI300__ + hip_fp8(float v) { + data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, + true /*clip*/>(v); + } + + explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) + : hip_fp8(static_cast(v)) {} + +#ifdef __HIP__MI300__ + // upcast using device specific intrinsic + explicit inline HIP_FP8_DEVICE operator float() const { + float fval; + uint32_t i32val = static_cast(data); + + // upcast + asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" + : "=v"(fval) + : "v"(i32val)); + + return fval; + } + + explicit inline HIP_FP8_HOST operator float() const +#else // __HIP__MI300__ + explicit inline HIP_FP8_HOST_DEVICE operator float() const +#endif // __HIP__MI300__ + { + return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>( + data); + } +}; + +namespace std { +inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); } +inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); } +HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; } +} // namespace std + +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) { + return os << float(f8); +} + +// all + operator overloading with mixed types +// mixed types, always converts to f32, does computation in f32, and returns +// float +inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) { + return (fa + float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) { + return (float(a) + fb); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) { + return hip_fp8(float(a) + float(b)); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) { + return a = hip_fp8(float(a) + float(b)); +} + +// overloading multiplication, always returns float, +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) { + return float(a) * float(b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) { + return (a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) { + return (float(a) * b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) { + return ((float)a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) { + return ((float)a * float(b)); +} + +// overloading for compare +inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) { + return (a.data == b.data); +} +inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) { + return (a.data != b.data); +} + +inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) { + return static_cast(a) >= static_cast(b); +} +inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) { + return static_cast(a) > static_cast(b); +} diff --git a/csrc/quantization/fp8/amd/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h new file mode 100644 index 0000000000000..90251c3539534 --- /dev/null +++ b/csrc/quantization/fp8/amd/hip_float8_impl.h @@ -0,0 +1,316 @@ +#pragma once + +#if defined(__HIPCC__) && \ + (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ +#endif + +#ifdef __HIPCC__ + #define HIP_FP8_HOST_DEVICE __host__ __device__ + #define HIP_FP8_HOST __host__ + #define HIP_FP8_DEVICE __device__ +#else + #define HIP_FP8_HOST_DEVICE + #define HIP_FP8_HOST + #define HIP_FP8_DEVICE +#endif + +namespace hip_fp8_impl { + +#ifdef __HIP__MI300__ +HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) { + uint8_t i8data; + union { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // NOTE: not endian independent + } val; + + uint32_t ival = 0; + val.fval = v; + + if ((val.i32val & 0x7F800000) != + 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + + return i8data; +} +#endif // __HIP__MI300__ + +HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); } +#if defined(__HIPCC__) || defined(__CUDA_ARCH__) +HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); } +#endif + +template +HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, + uint32_t rng = 0) { +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(wm + we == 7, "wm+we==7"); + static_assert(is_half || is_float, "Only half and float can be cast to f8"); + + const int mfmt = (sizeof(T) == 4) ? 23 : 10; + uint32_t x; + if (sizeof(T) == 4) { + x = reinterpret_cast(_x); + } else { + x = reinterpret_cast(_x); + } + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return 0x80; + } + } else { + // if(__hisinf(x) || __hisnan(x)) + if ((x & 0x7C00) == 0x7C00) { + return 0x80; + } + } + } else { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } else { + if ((x & 0x7C00) == 0x7C00) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } + } + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = + 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we +mostly concern fp16 here. In this case, f8 is usually in denormal. But there +could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has +exponent bias 16. It means that there are some numbers in fp16 denormal but they +are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 +(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = + f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal +range. For example fp8 nanoo mode, denormal exponent is -7, but if the +fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, +Therefore it needs to be adjust to -6 and mantissa shift right by 1. +So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no + // difference for this case, act_exponent could be + // larger. Just that it does not need shift mantissa + } + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + static_cast(1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part + and make something not midpoint look like midpoint. For example, the fp16 + number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after + shift right by 4 bits, it would look like midpoint. +*/ + + if (exponent_diff > 0) { + mantissa >>= exponent_diff; + } else if (exponent_diff == -1) { + mantissa <<= -exponent_diff; + } + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit + // that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & + drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1 << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } else { + if ((1 << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; + } + } + + if (f8_exponent == 0 && mantissa == 0) { + return negative_zero_nan ? 0 : (sign << 7); + } + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +template +inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) { +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported"); + + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + + T fInf, fNegInf, fNaN, fNeg0; + +#ifdef __HIPCC__ + if (is_half) { + const uint16_t ihInf = 0x7C00; + const uint16_t ihNegInf = 0xFC00; + const uint16_t ihNaN = 0x7C01; + const uint16_t ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else +#endif + if (is_float) { + const uint32_t ifInf = 0x7F800000; + const uint32_t ifNegInf = 0xFF800000; + const uint32_t ifNaN = 0x7F800001; + const uint32_t ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) { + return fNaN; + } + } else { + if (x == 0x80) { + return fNeg0; + } + if (exponent == ((1 << we) - 1)) { + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + } + typename std::conditional::type retval; + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; + return reinterpret_cast(retval); + } + + const int exp_low_cutoff = + (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) { + retval = (sign << 15) | (exponent << 10) | mantissa; + } else { + retval = (sign << 31) | (exponent << 23) | mantissa; + } + return reinterpret_cast(retval); +} + +} // namespace hip_fp8_impl diff --git a/csrc/quantization/fp8/amd_detail/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh similarity index 57% rename from csrc/quantization/fp8/amd_detail/quant_utils.cuh rename to csrc/quantization/fp8/amd/quant_utils.cuh index 9efe735cb0835..23d975fe0f37e 100644 --- a/csrc/quantization/fp8/amd_detail/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -5,304 +5,297 @@ #include #include -#include "../../../attention/dtype_float32.cuh" -#include "../../../attention/dtype_bfloat16.cuh" +#include "../../../attention/attention_dtypes.h" + +namespace vllm { +#ifdef USE_ROCM + +namespace fp8 { + #ifdef ENABLE_FP8 -namespace vllm -{ -namespace fp8_e4m3 { template -__inline__ __device__ Tout vec_conversion(const Tin& x) -{ - return x; +__inline__ __device__ Tout vec_conversion(const Tin& x) { + return x; } template -__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale) -{ - return x; +__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, + const float scale) { + return x; } // fp8 -> half template <> -__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8); - return res.x; +__inline__ __device__ uint16_t +vec_conversion(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8); + return res.x; } // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0]; - tmp.h2r.y.data = f2[1]; - return tmp.ui32; -#else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = vec_conversion(static_cast(a)); - tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); - return tmp.u32; -#endif +__inline__ __device__ uint32_t +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0]; + tmp.h2r.y.data = f2[1]; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = vec_conversion(static_cast(a)); + tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); + return tmp.u32; + #endif } // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 vec_conversion(const uint32_t& a) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = vec_conversion((uint16_t)a); - tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); - return tmp.u32x2; +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; } // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 vec_conversion(const uint2& a) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = vec_conversion(a.x); - tmp.u64[1] = vec_conversion(a.y); - return tmp.u64x2; +__inline__ __device__ uint4 vec_conversion(const uint2& a) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; } using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f); +__inline__ __device__ __nv_bfloat16 +vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f); } using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) -{ - __nv_bfloat162 res; - res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); - res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); - return res; +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; } // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) -{ - bf16_4_t res; - res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); - res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); - return res; +__inline__ __device__ bf16_4_t +vec_conversion(const uint32_t& a) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; } // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) -{ - bf16_4_t tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // fp8 -> float template <> -__inline__ __device__ float vec_conversion(const uint8_t& a) -{ - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8); +__inline__ __device__ float vec_conversion(const uint8_t& a) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8); } // fp8x2 -> float2 template <> -__inline__ __device__ float2 vec_conversion(const uint16_t& a) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0]; - res.y = f2[1]; - return res; -#else - float2 res; - res.x = vec_conversion(static_cast(a)); - res.y = vec_conversion(static_cast(a >> 8U)); - return res; -#endif +__inline__ __device__ float2 +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0]; + res.y = f2[1]; + return res; + #else + float2 res; + res.x = vec_conversion(static_cast(a)); + res.y = vec_conversion(static_cast(a >> 8U)); + return res; + #endif } // fp8x4 -> float4 template <> -__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) -{ - Float4_ res; - res.x = vec_conversion((uint16_t)a); - res.y = vec_conversion((uint16_t)(a >> 16U)); - return res; +__inline__ __device__ Float4_ +vec_conversion(const uint32_t& a) { + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; } // fp8x8 -> float8 template <> -__inline__ __device__ Float8_ vec_conversion(const uint2& a) -{ - Float4_ tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ Float8_ vec_conversion(const uint2& a) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // half -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) -{ - __half_raw tmp; - tmp.x = a; +__inline__ __device__ uint8_t +vec_conversion(const uint16_t& a) { + __half_raw tmp; + tmp.x = a; - hip_fp8 f8{static_cast(tmp.data)}; - return f8.data; + hip_fp8 f8{static_cast(tmp.data)}; + return f8.data; } // bf16 -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) -{ - hip_fp8 res{__bfloat162float(a)}; - return res.data; +__inline__ __device__ uint8_t +vec_conversion(const __nv_bfloat16& a) { + hip_fp8 res{__bfloat162float(a)}; + return res.data; } // float -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const float& a) -{ - hip_fp8 f8(a); - return f8.data; +__inline__ __device__ uint8_t vec_conversion(const float& a) { + hip_fp8 f8(a); + return f8.data; } // fp8x4 -> float4 template <> -__inline__ __device__ float4 vec_conversion(const uint32_t& a) -{ - Float4_ tmp = vec_conversion(a); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +__inline__ __device__ float4 +vec_conversion(const uint32_t& a) { + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; } // float2 -> half2 template <> -__inline__ __device__ uint32_t vec_conversion(const float2& a) -{ - union { - half2 float16; - uint32_t uint32; - }; +__inline__ __device__ uint32_t +vec_conversion(const float2& a) { + union { + half2 float16; + uint32_t uint32; + }; - float16 = __float22half2_rn(a); - return uint32; + float16 = __float22half2_rn(a); + return uint32; } // Float4 -> half2x2 template <> -__inline__ __device__ uint2 vec_conversion(const Float4_& a) -{ - uint2 b; - float2 val; - val.x = a.x.x; - val.y = a.x.y; - b.x = vec_conversion(val); +__inline__ __device__ uint2 vec_conversion(const Float4_& a) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); - val.x = a.y.x; - val.y = a.y.y; - b.y = vec_conversion(val); - return b; + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + return b; } // Float4 -> float4 template <> -__inline__ __device__ float4 vec_conversion(const Float4_& a) -{ - float4 b; - b.x = a.x.x; - b.y = a.x.y; - b.z = a.y.x; - b.w = a.y.y; - return b; +__inline__ __device__ float4 vec_conversion(const Float4_& a) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; } // Float8 -> half2x4 template <> -__inline__ __device__ uint4 vec_conversion(const Float8_& a) -{ - uint4 b; - b.x = vec_conversion(a.x); - b.y = vec_conversion(a.y); - b.z = vec_conversion(a.z); - b.w = vec_conversion(a.w); - return b; +__inline__ __device__ uint4 vec_conversion(const Float8_& a) { + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; } // float2 -> bfloat162 template <> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) -{ - __nv_bfloat162 b = __float22bfloat162_rn(a); - return b; +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, float2>(const float2& a) { + __nv_bfloat162 b = __float22bfloat162_rn(a); + return b; } // Float4 -> bfloat162x2 template <> -__inline__ __device__ bf16_4_t vec_conversion(const Float4_& a) -{ - bf16_4_t b; - b.x = __float22bfloat162_rn(a.x); - b.y = __float22bfloat162_rn(a.y); - return b; +__inline__ __device__ bf16_4_t +vec_conversion(const Float4_& a) { + bf16_4_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + return b; } // Float8 -> bfloat162x4 template <> -__inline__ __device__ bf16_8_t vec_conversion(const Float8_& a) -{ - bf16_8_t b; - b.x = __float22bfloat162_rn(a.x); - b.y = __float22bfloat162_rn(a.y); - b.z = __float22bfloat162_rn(a.z); - b.w = __float22bfloat162_rn(a.w); - return b; +__inline__ __device__ bf16_8_t +vec_conversion(const Float8_& a) { + bf16_8_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + b.z = __float22bfloat162_rn(a.z); + b.w = __float22bfloat162_rn(a.w); + return b; } - /* Scaled and vectorized conversions, for data exchange between high and low precision domains @@ -653,6 +646,60 @@ __inline__ __device__ uint32_t scaled_vec_conversion(const flo tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); return tmp.ui32; } + #endif // ENABLE_FP8 + +template +__inline__ __device__ Tout convert(const Tin& x) { + #ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x); + } + #endif + assert(false); +} + +template +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale); + } + #endif + assert(false); +} + + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. + #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } -} -} // namespace vllm +} // namespace fp8 +#endif // USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/fp8/amd_detail/hip_float8.h b/csrc/quantization/fp8/amd_detail/hip_float8.h deleted file mode 100644 index 87c7c9ce66100..0000000000000 --- a/csrc/quantization/fp8/amd_detail/hip_float8.h +++ /dev/null @@ -1,167 +0,0 @@ -#pragma once - -#ifdef __HIPCC__ -#include -#else -#include -#include -#include -#include -#endif - -#include "hip_float8_impl.h" - -struct alignas(1) hip_fp8 -{ - struct from_bits_t - { - }; - HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } - uint8_t data; - - hip_fp8() = default; - HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; - HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; - explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) - : data(v) - { - } - -#ifdef __HIP__MI300__ - // NOTE: ON-DEVICE... always optimal bias - explicit HIP_FP8_DEVICE hip_fp8(float v) - : data(hip_fp8_impl::to_fp8_from_fp32(v)) - { - } - - explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) - : hip_fp8(static_cast(v)) - { - } - - // Host only implementation using s/w simulation - explicit HIP_FP8_HOST -#else // __HIP__MI300__ - // both Host and DEVICE for non-MI300 using s/w simulation - explicit HIP_FP8_HOST_DEVICE -#endif // __HIP__MI300__ - hip_fp8(float v) - { - data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v); - } - - explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) - : hip_fp8(static_cast(v)) - { - } - -#ifdef __HIP__MI300__ - // upcast using device specific intrinsic - explicit inline HIP_FP8_DEVICE operator float() const - { - float fval; - uint32_t i32val = static_cast(data); - - // upcast - asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); - - return fval; - } - - explicit inline HIP_FP8_HOST operator float() const -#else // __HIP__MI300__ - explicit inline HIP_FP8_HOST_DEVICE operator float() const -#endif // __HIP__MI300__ - { - return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data); - } -}; - -namespace std -{ -inline hip_fp8 sin(hip_fp8 a) -{ - return hip_fp8(sinf(float(a))); -} -inline hip_fp8 cos(hip_fp8 a) -{ - return hip_fp8(cosf(float(a))); -} -HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) -{ - return a; -} -} // namespace std - -// Special operator overloading -inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) -{ - return os << float(f8); -} - -// all + operator overloading with mixed types -// mixed types, always converts to f32, does computation in f32, and returns float -inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) -{ - return (fa + float(b)); -} - -inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) -{ - return (float(a) + fb); -} - -inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) -{ - return hip_fp8(float(a) + float(b)); -} - -inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) -{ - return a = hip_fp8(float(a) + float(b)); -} - -// overloading multiplication, always returns float, -inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) -{ - return float(a) * float(b); -} - -inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) -{ - return (a * float(b)); -} - -inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) -{ - return (float(a) * b); -} - -inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) -{ - return ((float)a * float(b)); -} - -inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) -{ - return ((float)a * float(b)); -} - -// overloading for compare -inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) -{ - return (a.data == b.data); -} -inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) -{ - return (a.data != b.data); -} - -inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) -{ - return static_cast(a) >= static_cast(b); -} -inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) -{ - return static_cast(a) > static_cast(b); -} diff --git a/csrc/quantization/fp8/amd_detail/hip_float8_impl.h b/csrc/quantization/fp8/amd_detail/hip_float8_impl.h deleted file mode 100644 index e05905b4e49e8..0000000000000 --- a/csrc/quantization/fp8/amd_detail/hip_float8_impl.h +++ /dev/null @@ -1,316 +0,0 @@ -#pragma once - -#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) -#define __HIP__MI300__ -#endif - -#ifdef __HIPCC__ -#define HIP_FP8_HOST_DEVICE __host__ __device__ -#define HIP_FP8_HOST __host__ -#define HIP_FP8_DEVICE __device__ -#else -#define HIP_FP8_HOST_DEVICE -#define HIP_FP8_HOST -#define HIP_FP8_DEVICE -#endif - -namespace hip_fp8_impl -{ - -#ifdef __HIP__MI300__ -HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) -{ - uint8_t i8data; - union { - float fval; - uint32_t i32val; - uint8_t i8val[4]; // NOTE: not endian independent - } val; - - uint32_t ival = 0; - val.fval = v; - - if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping - val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); - } - - ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, - false); // false -> WORD0 - val.i32val = ival; - i8data = val.i8val[0]; - - return i8data; -} -#endif // __HIP__MI300__ - -HIP_FP8_HOST inline int clz(uint32_t x) -{ - return __builtin_clz(x); -} -#if defined(__HIPCC__) || defined(__CUDA_ARCH__) -HIP_FP8_DEVICE inline int clz(uint32_t x) -{ - return __clz(x); -} -#endif - -template -HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0) -{ -#ifdef __HIPCC__ - constexpr bool is_half = std::is_same::value; -#else - constexpr bool is_half = false; -#endif - constexpr bool is_float = std::is_same::value; - static_assert(wm + we == 7, "wm+we==7"); - static_assert(is_half || is_float, "Only half and float can be cast to f8"); - - const int mfmt = (sizeof(T) == 4) ? 23 : 10; - uint32_t x; - if (sizeof(T) == 4) { - x = reinterpret_cast(_x); - } else { - x = reinterpret_cast(_x); - } - - uint32_t head, mantissa; - int exponent, bias; - uint32_t sign; - - if (sizeof(T) == 4) { - head = x & 0xFF800000; - mantissa = x & 0x7FFFFF; - exponent = (head >> 23) & 0xFF; - sign = head >> 31; - bias = 127; - } else { - head = x & 0xFC00; - mantissa = x & 0x3FF; - exponent = (head >> 10) & 0x1F; - sign = head >> 15; - bias = 15; - } - - uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); - - // Deal with inf and NaNs - if (negative_zero_nan) { - if (sizeof(T) == 4) { - if ((x & 0x7F800000) == 0x7F800000) { - return 0x80; - } - } else { - // if(__hisinf(x) || __hisnan(x)) - if ((x & 0x7C00) == 0x7C00) { - return 0x80; - } - } - } else { - if (sizeof(T) == 4) { - if ((x & 0x7F800000) == 0x7F800000) { - return signed_inf + (mantissa != 0 ? 1 : 0); - } - } else { - if ((x & 0x7C00) == 0x7C00) { - return signed_inf + (mantissa != 0 ? 1 : 0); - } - } - } - if (x == 0) { - return 0; - } - - // First need to check if it is normal or denorm as there is a difference of - // implicit 1 Then need to adjust the exponent to align with the F8 exponent, - // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng - // to mantissa and truncate. And for RNE, no need to add rng. Then probably - // need to check whether there is carry and adjust exponent and mantissa again - - // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent - // bits - const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); - const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal - // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) - // f8_exponent is the converted f8 exponent with bias encoding - // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, - // the difference needs to be adjusted and mantissa shifted - int act_exponent, f8_exponent, exponent_diff; - - if (exponent == 0) { // fp32/fp16 is in denormal. - /* fp32 denormal is below 2^-127 so it is usually not a concern here, we -mostly concern fp16 here. In this case, f8 is usually in denormal. But there -could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has -exponent bias 16. It means that there are some numbers in fp16 denormal but they -are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers -where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 -(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ - act_exponent = exponent - bias + 1; - exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal - } else { // fp32/fp16 is normal with implicit 1 - act_exponent = exponent - bias; - if (act_exponent <= f8_denormal_act_exponent) { - /* This is the case where fp32/fp16 is normal but it is in f8 denormal - range. For example fp8 nanoo mode, denormal exponent is -7, but if the - fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, - Therefore it needs to be adjust to -6 and mantissa shift right by 1. - So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ - exponent_diff = f8_denormal_act_exponent - act_exponent; - } else { // both fp32/fp16 and f8 are in normal range - exponent_diff = 0; // exponent_diff=0 does not mean there is no difference - // for this case, - // act_exponent could be larger. Just that it does not need shift mantissa - } - mantissa += (1 << mfmt); // Add the implicit 1 into mantissa - } - - bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == - static_cast(1 << (mfmt - wm + exponent_diff - 1)); - /* This part is a bit tricky. The judgment of whether it is a tie needs to be - done before we shift right as shift right could rip off some residual part - and make something not midpoint look like midpoint. For example, the fp16 - number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after - shift right by 4 bits, it would look like midpoint. -*/ - - if (exponent_diff > 0) { - mantissa >>= exponent_diff; - } else if (exponent_diff == -1) { - mantissa <<= -exponent_diff; - } - bool implicit_one = mantissa & (1 << mfmt); - // if there is no implicit 1, it means the f8 is denormal and need to adjust - // to denorm exponent - f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); - - // Now we have the exponent and mantissa adjusted - uint32_t drop_mask = (1 << (mfmt - wm)) - 1; - bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that - // is not truncated is 1 - mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; - - // Now we deal with overflow - if (f8_exponent == 0) { - if ((1 << mfmt) & mantissa) { - f8_exponent = 1; // denormal overflow to become normal, promote exponent - } - } else { - if ((1 << (mfmt + 1)) & mantissa) { - mantissa >>= 1; - f8_exponent++; - } - } - - mantissa >>= (mfmt - wm); - - // above range: quantize to maximum possible float of the same sign - const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); - if (f8_exponent > max_exp) { - if (clip) { - mantissa = (1 << wm) - 1; - f8_exponent = max_exp; - } else { - return signed_inf; - } - } - - if (f8_exponent == 0 && mantissa == 0) { - return negative_zero_nan ? 0 : (sign << 7); - } - mantissa &= (1 << wm) - 1; - return (sign << 7) | (f8_exponent << wm) | mantissa; -} - -template -inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) -{ -#ifdef __HIPCC__ - constexpr bool is_half = std::is_same::value; -#else - constexpr bool is_half = false; -#endif - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "only half and float are supported"); - - constexpr int weo = is_half ? 5 : 8; - constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); - - T fInf, fNegInf, fNaN, fNeg0; - -#ifdef __HIPCC__ - if (is_half) { - const uint16_t ihInf = 0x7C00; - const uint16_t ihNegInf = 0xFC00; - const uint16_t ihNaN = 0x7C01; - const uint16_t ihNeg0 = 0x8000; - fInf = reinterpret_cast(ihInf); - fNegInf = reinterpret_cast(ihNegInf); - fNaN = reinterpret_cast(ihNaN); - fNeg0 = reinterpret_cast(ihNeg0); - } else -#endif - if (is_float) { - const uint32_t ifInf = 0x7F800000; - const uint32_t ifNegInf = 0xFF800000; - const uint32_t ifNaN = 0x7F800001; - const uint32_t ifNeg0 = 0x80000000; - fInf = reinterpret_cast(ifInf); - fNegInf = reinterpret_cast(ifNegInf); - fNaN = reinterpret_cast(ifNaN); - fNeg0 = reinterpret_cast(ifNeg0); - } - - if (x == 0) { - return 0; - } - - uint32_t sign = x >> 7; - uint32_t mantissa = x & ((1 << wm) - 1); - int exponent = (x & 0x7F) >> wm; - if (negative_zero_nan) { - if (x == 0x80) { - return fNaN; - } - } else { - if (x == 0x80) { - return fNeg0; - } - if (exponent == ((1 << we) - 1)) { - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; - } - } - typename std::conditional::type retval; - if (we == 5 && is_half && !negative_zero_nan) { - retval = x << 8; - return reinterpret_cast(retval); - } - - const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); - - // subnormal input - if (exponent == 0) { - // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + clz(mantissa) - (32 - wm); - mantissa <<= sh; - exponent += 1 - sh; - mantissa &= ((1 << wm) - 1); - } - exponent += exp_low_cutoff - 1; - mantissa <<= wmo - wm; - - // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) - if (exponent <= 0) { - mantissa |= 1 << wmo; - mantissa >>= 1 - exponent; - exponent = 0; - } - - if (sizeof(T) == 2) { - retval = (sign << 15) | (exponent << 10) | mantissa; - } else { - retval = (sign << 31) | (exponent << 23) | mantissa; - } - return reinterpret_cast(retval); -} - -} // namespace hip_fp8_impl diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu new file mode 100644 index 0000000000000..937df5a0bec13 --- /dev/null +++ b/csrc/quantization/fp8/common.cu @@ -0,0 +1,211 @@ +#include +#include +#include + +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +#ifdef USE_ROCM + #include "amd/quant_utils.cuh" +#else + #include "nvidia/quant_utils.cuh" +#endif + +namespace vllm { + +template +__global__ void convert_fp8_kernel( + const Tin* __restrict__ src_data, Tout* __restrict__ dst_data, const float* scale, size_t N) +{ + const int64_t block_idx = blockIdx.x; + + using V_in_vec = typename Vec::Type; + using V_out_vec = typename Vec::Type; + auto dst_data_vec = reinterpret_cast(dst_data); + auto src_data_vec = reinterpret_cast(src_data); + + int64_t startIdx = (threadIdx.x + blockDim.x * blockIdx.x); + auto idx = startIdx; + if (idx >= N) { + return; + } + dst_data_vec[idx] = fp8::scaled_vec_conversion(src_data_vec[idx], *scale); + //dst_data_vec[idx+1] = fp8_e4m3::vec_conversion(src_data_vec[idx+1], *scale); + + //for (int64_t i = 0; i < loopSize; ++i) { + // auto idx = startIdx + i; + // if (idx >= N) { + // return; + // } + // dst_data_vec[idx] = fp8_e4m3::vec_conversion(src_data_vec[idx], *scale); + //} +} + +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { + float old; + old = (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(value))); + + return old; +} + +#define FP8_E4M3_MAX std::numeric_limits::max() + +template +__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion( + const scalar_t val, const float scale) { + float x = static_cast(val) / scale; + float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + return static_cast(r); +} + +// Compute the absolute maximum m of the input tensor and store +// m / float8_e4m3::max() in *scale. Each thread block performs a +// reduction tree and the memory in scale is atomically updated. +// So to get the right answer, *scale needs to be initialized to +// a value <= 0.0 and we need to wait for all thread blocks to +// finish before consuming *scale. +template +__global__ void segmented_max_reduction(float* __restrict__ scale, + const scalar_t* __restrict__ input, + int64_t num_elems) { + __shared__ float cache[1024]; + int i = blockDim.x * blockIdx.x + threadIdx.x; + + // First store maximum for all values processes by + // the current thread in cache[threadIdx.x] + scalar_t tmp = 0.0; + while (i < num_elems) { + float x = static_cast(input[i]); + tmp = max(tmp, fabs(x)); + i += blockDim.x * gridDim.x; + } + cache[threadIdx.x] = tmp; + + __syncthreads(); + + // Now perform parallel reduction within the thread block + int ib = blockDim.x / 2; + while (ib != 0) { + if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { + cache[threadIdx.x] = cache[threadIdx.x + ib]; + } + __syncthreads(); + ib /= 2; + } + // Finally, since cache[0] contains the maximum for this thread block, + // atomically write the max to the target location + if (threadIdx.x == 0) { + atomicMaxFloat(scale, + cache[0] / std::numeric_limits::max()); + } +} + +template +__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, + const scalar_t* __restrict__ input, + const float* __restrict__ scale, + int64_t num_elems) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + while (i < num_elems) { + out[i] = scaled_fp8_conversion(input[i], *scale); + i += blockDim.x * gridDim.x; + } +} + +} // namespace vllm + +void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] +{ + int64_t num_tokens = input.numel() / input.size(-1); + int64_t num_elems = input.numel(); + dim3 grid(num_tokens); + dim3 block(1024); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "scaled_fp8_quant_kernel", [&] { + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); + }); +} + +void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] +{ + int64_t num_tokens = input.numel() / input.size(-1); + int64_t num_elems = input.numel(); + dim3 grid(num_tokens); + dim3 block(1024); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "scaled_fp8_quant_kernel", [&] { + vllm::segmented_max_reduction<<>>( + scale.data_ptr(), input.data_ptr(), num_elems); + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); + }); +} + +template +struct call_convert_fp8 +{ + void operator()(torch::Tensor& src_data, torch::Tensor& dst_data, torch::Tensor& scale) + { + const auto N = src_data.numel() / 2; + //std::cout << N << "\n"; + constexpr uint32_t loopSize = 1;//std::max(N / 50000000LL, 1); + constexpr dim3 numThreads{1024, 1, 1}; + auto neededBlocks = (N + (numThreads.x * loopSize) - 1) / (numThreads.x * loopSize); + uint32_t actualBlocks = neededBlocks; + + //static uint32_t maxBlocks = 0; + //if (actualBlocks != maxBlocks) { + // maxBlocks = actualBlocks; + // std::cout << actualBlocks << "\n"; + //} + + const dim3 grid{actualBlocks, 1, 1}; + + const auto stream = at::cuda::getCurrentCUDAStream(); + + vllm::convert_fp8_kernel + <<>>(reinterpret_cast(src_data.data_ptr()), + reinterpret_cast(dst_data.data_ptr()), (float*)scale.data_ptr(), N); + } +}; + +void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, torch::Tensor& scale) +{ + torch::Device src_device = src_data.device(); + torch::Device dst_device = dst_data.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + auto t1 = src_data.dtype(); + auto t2 = dst_data.dtype(); + if (src_data.dtype() == at::ScalarType::Float) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (src_data.dtype() == at::ScalarType::Half) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (src_data.dtype() == at::ScalarType::BFloat16) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::Float) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::Half) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::BFloat16) { + call_convert_fp8<__nv_bfloat16, uint8_t, 2>{}(src_data, dst_data, scale); + } +} \ No newline at end of file diff --git a/csrc/quantization/fp8/convert_kernel.cu b/csrc/quantization/fp8/convert_kernel.cu deleted file mode 100644 index 74636c9df9ffd..0000000000000 --- a/csrc/quantization/fp8/convert_kernel.cu +++ /dev/null @@ -1,96 +0,0 @@ -#include -#include -#include - -#include "../../attention/attention_dtypes.h" -#if defined(ENABLE_FP8_E5M2) -#include "../fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "amd_detail/quant_utils.cuh" -#endif - -namespace vllm -{ - -template -__global__ void convert_fp8_kernel( - const Tin* __restrict__ src_data, Tout* __restrict__ dst_data, const float* scale, size_t N) -{ - const int64_t block_idx = blockIdx.x; - - using V_in_vec = typename Vec::Type; - using V_out_vec = typename Vec::Type; - auto dst_data_vec = reinterpret_cast(dst_data); - auto src_data_vec = reinterpret_cast(src_data); - - int64_t startIdx = (threadIdx.x + blockDim.x * blockIdx.x); - auto idx = startIdx; - if (idx >= N) { - return; - } - dst_data_vec[idx] = fp8_e4m3::scaled_vec_conversion(src_data_vec[idx], *scale); - //dst_data_vec[idx+1] = fp8_e4m3::vec_conversion(src_data_vec[idx+1], *scale); - - //for (int64_t i = 0; i < loopSize; ++i) { - // auto idx = startIdx + i; - // if (idx >= N) { - // return; - // } - // dst_data_vec[idx] = fp8_e4m3::vec_conversion(src_data_vec[idx], *scale); - //} -} - -} // namespace vllm - -template -struct call_convert_fp8 -{ - void operator()(torch::Tensor& src_data, torch::Tensor& dst_data, torch::Tensor& scale) - { - const auto N = src_data.numel() / 2; - //std::cout << N << "\n"; - constexpr uint32_t loopSize = 1;//std::max(N / 50000000LL, 1); - constexpr dim3 numThreads{1024, 1, 1}; - auto neededBlocks = (N + (numThreads.x * loopSize) - 1) / (numThreads.x * loopSize); - uint32_t actualBlocks = neededBlocks; - - //static uint32_t maxBlocks = 0; - //if (actualBlocks != maxBlocks) { - // maxBlocks = actualBlocks; - // std::cout << actualBlocks << "\n"; - //} - - const dim3 grid{actualBlocks, 1, 1}; - - const auto stream = at::cuda::getCurrentCUDAStream(); - - vllm::convert_fp8_kernel - <<>>(reinterpret_cast(src_data.data_ptr()), - reinterpret_cast(dst_data.data_ptr()), (float*)scale.data_ptr(), N); - } -}; - -void convert_fp8(torch::Tensor& src_data, torch::Tensor& dst_data, torch::Tensor& scale) -{ - torch::Device src_device = src_data.device(); - torch::Device dst_device = dst_data.device(); - TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") - TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") - TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same GPU"); - at::cuda::OptionalCUDAGuard device_guard(src_device); - auto t1 = src_data.dtype(); - auto t2 = dst_data.dtype(); - if (src_data.dtype() == at::ScalarType::Float) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (src_data.dtype() == at::ScalarType::Half) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (src_data.dtype() == at::ScalarType::BFloat16) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (dst_data.dtype() == at::ScalarType::Float) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (dst_data.dtype() == at::ScalarType::Half) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (dst_data.dtype() == at::ScalarType::BFloat16) { - call_convert_fp8<__nv_bfloat16, uint8_t, 2>{}(src_data, dst_data, scale); - } -} \ No newline at end of file diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh new file mode 100644 index 0000000000000..cde26dbda18cf --- /dev/null +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -0,0 +1,570 @@ +#pragma once + +#include "../../../attention/attention_dtypes.h" +#include +#include +#include +#include + +namespace vllm { +#ifndef USE_ROCM + +namespace fp8 { + #ifdef ENABLE_FP8 + + #if 0 // Disable the following code to reduce the binary size. +template +__inline__ __device__ Tout +vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t vec_conversion( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = res.x; + tmp.u16[1] = res.y; + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a, fp8_type); + tmp.u32[1] = + vec_conversion((uint16_t)(a >> 16U), fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x, fp8_type); + tmp.u64[1] = vec_conversion(a.y, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type); + res.y = + vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float +vec_conversion(const uint8_t &a, + const __nv_fp8_interpretation_t fp8_type) { + // fp8 -> half + uint16_t tmp = vec_conversion(a, fp8_type); + // half -> float + return half_to_float(tmp); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = vec_conversion(a, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = vec_conversion((uint16_t)a, fp8_type); + res.y = vec_conversion((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp; + tmp.x = a; + __nv_fp8_storage_t res = + __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); + #else + __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( + __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); + return (uint8_t)res; + #endif +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const float &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = vec_conversion(a, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +template <> +__inline__ __device__ uint32_t vec_conversion( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template <> +__inline__ __device__ uint2 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val, fp8_type); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val, fp8_type); + + return b; +} + +template <> +__inline__ __device__ float4 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template <> +__inline__ __device__ uint4 vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint4 b; + b.x = vec_conversion(a.x, fp8_type); + b.y = vec_conversion(a.y, fp8_type); + b.z = vec_conversion(a.z, fp8_type); + b.w = vec_conversion(a.w, fp8_type); + return b; +} + +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_8_t b; + from_float(b, a); + return b; +} + #endif + +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains Convention of the scale in API, e.g: FP8_data = + Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 + Dequant(FP8) * scale => HP + */ + +template +__inline__ __device__ Tout scaled_vec_conversion( + const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const uint8_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return float_to_half(half_to_float(tmp.x) * scale); +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion( + const uint16_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = float_to_half(half_to_float(res.x) * scale); + tmp.u16[1] = float_to_half(half_to_float(res.y) * scale); + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 scaled_vec_conversion( + const uint32_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = + scaled_vec_conversion((uint16_t)a, scale, fp8_type); + tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), + scale, fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 +scaled_vec_conversion(const uint2& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale, fp8_type); + tmp.u64[1] = scaled_vec_conversion(a.y, scale, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp * scale); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), + scale, fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t scaled_vec_conversion( + const uint32_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale, fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t scaled_vec_conversion( + const uint2& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion( + const uint8_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + uint16_t tmp = res.x; + + // half -> float + return half_to_float(tmp) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 scaled_vec_conversion( + const uint16_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = scaled_vec_conversion(a, scale, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ scaled_vec_conversion( + const uint32_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale, + fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ scaled_vec_conversion( + const uint2& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const uint16_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); + #else + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, + __NV_SATFINITE, fp8_type); + return (uint8_t)res; + #endif +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const float& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 scaled_vec_conversion( + const uint32_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + #endif // ENABLE_FP8 + +template +__inline__ __device__ Tout convert(const Tin& x) { + #if 0 // Disable the following code to reduce the binary size. + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return vec_conversion(x, __NV_E5M2); + } + #endif + assert(false); +} + +template +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return scaled_vec_conversion(x, scale, __NV_E5M2); + } + #endif + assert(false); +} + + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. + #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else if (KV_DTYPE == "fp8_e5m2") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // namespace fp8 +#endif // not USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh b/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh deleted file mode 100644 index 9bcab25db03cf..0000000000000 --- a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh +++ /dev/null @@ -1,277 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include "../../attention/attention_dtypes.h" -#include "../../attention/dtype_float32.cuh" -#include "../../attention/dtype_float16.cuh" -#include "../../attention/dtype_bfloat16.cuh" - - -namespace vllm { -#ifdef ENABLE_FP8_E5M2 -namespace fp8_e5m2_unscaled { - -template -__inline__ __device__ Tout vec_conversion(const Tin& x) -{ - return x; -} - -// fp8 -> half -template<> -__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) -{ - __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); - return res.x; -} - -// fp8x2 -> half2 -template<> -__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) -{ - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2); - tmp.u16[0] = res.x; - tmp.u16[1] = res.y; - return tmp.u32; -} - -// fp8x4 -> half2x2 -template<> -__inline__ __device__ uint2 vec_conversion(const uint32_t& a) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = vec_conversion((uint16_t)a); - tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); - return tmp.u32x2; -} - -// fp8x8 -> half2x4 -template<> -__inline__ __device__ uint4 vec_conversion(const uint2& a) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = vec_conversion(a.x); - tmp.u64[1] = vec_conversion(a.y); - return tmp.u64x2; -} - -// fp8 -> __nv_bfloat16 -template<> -__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) -{ - // Note there is no direct convert function from fp8 to bf16. - // fp8 -> half - __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); - // half -> float -> bf16 - float tmp = half_to_float(res.x); - return __float2bfloat16(tmp); -} - -// fp8x2 -> __nv_bfloat162 -template<> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) -{ - __nv_bfloat162 res; - res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); - res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); - return res; -} - -// fp8x4 -> bf16_4_t -template<> -__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) -{ - bf16_4_t res; - res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); - res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); - return res; -} - -// fp8x8 -> bf16_8_t -template<> -__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) -{ - bf16_4_t tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; -} - -// fp8 -> float -template<> -__inline__ __device__ float vec_conversion(const uint8_t& a) -{ - // fp8 -> half - uint16_t tmp = vec_conversion(a); - // half -> float - return half_to_float(tmp); -} - -// fp8x2 -> float2 -template<> -__inline__ __device__ float2 vec_conversion(const uint16_t& a) -{ - // fp8x2 -> half2 - uint32_t tmp = vec_conversion(a); - // half2 -> float2 - return half2_to_float2(tmp); -} - -// fp8x4 -> float4 -template<> -__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) -{ - Float4_ res; - res.x = vec_conversion((uint16_t)a); - res.y = vec_conversion((uint16_t)(a >> 16U)); - return res; -} - -// fp8x8 -> float8 -template<> -__inline__ __device__ Float8_ vec_conversion(const uint2& a) -{ - Float4_ tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; -} - - -// half -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) -{ - __half_raw tmp; - tmp.x = a; - __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -} - -// bf16 -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - assert(false); -#else - __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -#endif -} - -// float -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const float& a) -{ - __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -} - -// fp8x4 -> float4 -template<> -__inline__ __device__ float4 vec_conversion(const uint32_t& a) -{ - Float4_ tmp = vec_conversion(a); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; -} - - -template<> -__inline__ __device__ uint32_t vec_conversion(const float2& a) -{ - union { - half2 float16; - uint32_t uint32; - }; - - float16 = __float22half2_rn(a); - return uint32; -} - -template<> -__inline__ __device__ uint2 vec_conversion(const Float4_& a) -{ - uint2 b; - float2 val; - val.x = a.x.x; - val.y = a.x.y; - b.x = vec_conversion(val); - - val.x = a.y.x; - val.y = a.y.y; - b.y = vec_conversion(val); - - return b; -} - -template<> -__inline__ __device__ float4 vec_conversion(const Float4_& a) -{ - float4 b; - b.x = a.x.x; - b.y = a.x.y; - b.z = a.y.x; - b.w = a.y.y; - return b; -} - -template<> -__inline__ __device__ uint4 vec_conversion(const Float8_& a) -{ - uint4 b; - b.x = vec_conversion(a.x); - b.y = vec_conversion(a.y); - b.z = vec_conversion(a.z); - b.w = vec_conversion(a.w); - return b; -} - -template<> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { - __nv_bfloat162 b; - from_float(b, a); - return b; -} - -template<> -__inline__ __device__ bf16_4_t vec_conversion(const Float4_ &a) { - bf16_4_t b; - from_float(b, a); - return b; -} - -template<> -__inline__ __device__ bf16_8_t vec_conversion(const Float8_ &a) { - bf16_8_t b; - from_float(b, a); - return b; -} - -} // namespace fp8_e5m2_unscaled -#endif // ENABLE_FP8_E5M2 -} // namespace vllm diff --git a/csrc/quantization/gptq/compat.cuh b/csrc/quantization/gptq/compat.cuh index 4da0bc6e2df38..1b3fb3d39103f 100644 --- a/csrc/quantization/gptq/compat.cuh +++ b/csrc/quantization/gptq/compat.cuh @@ -9,54 +9,54 @@ namespace vllm { namespace gptq { // atomicAdd for half types, to support CC < 7.x -__device__ __forceinline__ void atomicAdd_half(half* address, half val) -{ - unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; +__device__ __forceinline__ void atomicAdd_half(half* address, half val) { + unsigned int* address_as_ui = + (unsigned int*)((char*)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; - do - { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } - while (assumed != old); + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) + : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); } // atomicAdd for half2 types -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) -{ - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do - { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } - while (assumed != old); +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } while (assumed != old); } // #if defined(__CUDA_ARCH__) || defined(USE_ROCM) -#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } +__device__ __forceinline__ void atomicAdd(half* address, half val) { + atomicAdd_half(address, val); +} -#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } -#endif + #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { + atomicAdd_half2(address, val); +} + #endif -#endif + #endif #endif } // namespace gptq diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh index eda3436eb5375..2b6719fbdc1bc 100644 --- a/csrc/quantization/gptq/matrix_view.cuh +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -1,5 +1,6 @@ /* -Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/turboderp/exllama */ #ifndef _matrix_view_cuh @@ -13,260 +14,280 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo namespace vllm { namespace gptq { -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } - - __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const - { - half2* ptr = (half2*) item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __low2half(i01); - items[1] = __high2half(i01); - items[2] = __low2half(i23); - items[3] = __high2half(i23); - } - __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const - { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2float(__low2half(i01)); - items[1] = __half2float(__high2half(i01)); - items[2] = __half2float(__low2half(i23)); - items[3] = __half2float(__high2half(i23)); - } - - __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const - { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2half2(__low2half(i01)); - items[1] = __half2half2(__high2half(i01)); - items[2] = __half2half2(__low2half(i23)); - items[3] = __half2half2(__high2half(i23)); - } +class MatrixView_half { + public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { + return __half2half2(data[row * width + column]); + } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { + return &data[row * width + column]; + } + + __device__ __forceinline__ void item4(half (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } }; -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } - - __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) - { - half2 v01 = __halves2half2(v0, v1); - half2 v23 = __halves2half2(v2, v3); - half2* ptr = (half2*) item_ptr(row, column); - ptr[0] = v01; - ptr[1] = v23; - } +class MatrixView_half_rw { + public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half* item_ptr(int row, int column) { + return &data[row * width + column]; + } + __device__ __forceinline__ void set(int row, int column, half value) { + data[row * width + column] = value; + } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { + ((half2*)data)[(row * width + column) / 2] = value; + } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, + half v2, half v3) { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*)item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } }; -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - items[2] = (d >> 8) & 0x0f; - items[3] = (d >> 12) & 0x0f; - } +class MatrixView_q4_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } }; -class MatrixView_q4_column -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +class MatrixView_q4_column { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { + return data[row / 8 * width + column]; + } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, + int column) { + return &data[row / 8 * width + column]; + } }; -class MatrixView_q2_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x0f) * 2; - return (data[row * width / 16 + column / 16] >> shift) & 0x03; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x0f) * 2; - uint32_t d = data[row * width / 16 + column / 16] >> shift; - items[0] = d & 0x03; - items[1] = (d >> 2) & 0x03; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x0f) * 2; - uint32_t d = data[row * width / 16 + column / 16] >> shift; - items[0] = d & 0x03; - items[1] = (d >> 2) & 0x03; - items[2] = (d >> 4) & 0x03; - items[3] = (d >> 6) & 0x03; - } +class MatrixView_q2_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x0f) * 2; + return (data[row * width / 16 + column / 16] >> shift) & 0x03; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + items[2] = (d >> 4) & 0x03; + items[3] = (d >> 6) & 0x03; + } }; -class MatrixView_q3_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int z_w = column * 3 / 32; - int z_mod = column & 0x1f; - - if (z_mod == 10) { - return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); - } else if (z_mod == 21) { - return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); - } else if (z_mod < 10) { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; - } else if (z_mod < 21) { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; - } else { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; - } +class MatrixView_q3_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int z_w = column * 3 / 32; + int z_mod = column & 0x1f; + + if (z_mod == 10) { + return (data[row * width * 3 / 32 + z_w] >> 30) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); + } else if (z_mod == 21) { + return (data[row * width * 3 / 32 + z_w] >> 31) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); + } else if (z_mod < 10) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; + } else if (z_mod < 21) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; + } else { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x1f); - uint32_t d; - if (shift <= 4) { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); - } else if (shift == 8) { - d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); - } else if (shift <= 16) { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); - } else if (shift == 20) { - d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); - } else { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); - } - items[0] = d & 0x07; - items[1] = (d >> 3) & 0x07; - items[2] = (d >> 6) & 0x07; - items[3] = (d >> 9) & 0x07; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x1f); + uint32_t d; + if (shift <= 4) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); + } else if (shift == 8) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); + } else if (shift <= 16) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); + } else if (shift == 20) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); + } else { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); } + items[0] = d & 0x07; + items[1] = (d >> 3) & 0x07; + items[2] = (d >> 6) & 0x07; + items[3] = (d >> 9) & 0x07; + } }; -class MatrixView_q8_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x03) * 8; - return (data[row * width / 4 + column / 4] >> shift) & 0xff; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x03) * 8; - uint32_t d = data[row * width / 4 + column / 4] >> shift; - items[0] = d & 0xff; - items[1] = (d >> 8) & 0xff; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x03) * 2; - uint32_t d = data[row * width / 4 + column / 4] >> shift; - items[0] = d & 0xff; - items[1] = (d >> 8) & 0xff; - items[2] = (d >> 16) & 0xff; - items[3] = (d >> 24) & 0xff; - } +class MatrixView_q8_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x03) * 8; + return (data[row * width / 4 + column / 4] >> shift) & 0xff; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x03) * 8; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x03) * 2; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + items[2] = (d >> 16) & 0xff; + items[3] = (d >> 24) & 0xff; + } }; } // namespace gptq diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 655158e38f557..480c4986c3821 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -1,5 +1,6 @@ /* -Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/qwopqwop200/GPTQ-for-LLaMa */ #include @@ -32,2044 +33,1824 @@ namespace gptq { #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) #if defined(USE_ROCM) -#include -__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, - hipblasOperation_t transA, - hipblasOperation_t transB, - int m, - int n, - int k, - const half* alpha, - const half* AP, - int lda, - const half* BP, - int ldb, - const half* beta, - half* CP, - int ldc) { - return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); + #include +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( + hipblasHandle_t handle, hipblasOperation_t transA, + hipblasOperation_t transB, int m, int n, int k, const half* alpha, + const half* AP, int lda, const half* BP, int ldb, const half* beta, + half* CP, int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); } -#define hipblasHgemm __compat_hipblasHgemm + #define hipblasHgemm __compat_hipblasHgemm -// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. -#define rocblas_operation_none HIPBLAS_OP_N -#define rocblas_hgemm __compat_hipblasHgemm + // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. + #define rocblas_operation_none HIPBLAS_OP_N + #define rocblas_hgemm __compat_hipblasHgemm #endif -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hadd2(result, g_result); +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, + const half2 g_result) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); } -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __half2float(__low2half(result)) + __half2float(__high2half(result)); +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); } -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) -{ - // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 - - float result = {}; - #pragma unroll - for (int i = 0; i < 4; i++) - { - half2 w01 = dq[i]; - float w0 = __low2float(w01); - float w1 = __high2float(w01); - float x0 = __half2float(*a_ptr++); - float x1 = __half2float(*a_ptr++); - result = fma(w0, x0, result); - result = fma(w1, x1, result); - } - float qs = __half2float(qs_h); - result *= qs; - half result_h = __float2half_rn(result); - return __hadd(result_h, g_result); +__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr, + const half g_result, + const half qs_h) { + // Use FP32 accumulator to avoid potential overflow since unscaled weights are + // in the range -128..127 + + float result = {}; +#pragma unroll + for (int i = 0; i < 4; i++) { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); } -__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - half result_h = __hadd(__low2half(result), __high2half(result)); - return __hfma(result_h, qs_h, g_result); +__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); } -__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - half result_h = __hadd(__low2half(result), __high2half(result)); - return __hfma(result_h, qs_h, g_result); +__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); } - -typedef void (*fp_gemm_half_q_half_gptq_kernel) -( - const half*, - const uint32_t*, - const uint32_t*, - const half*, - half*, - const int, - const int, - const int, - const int, - const int* -); - +typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*, + const uint32_t*, const half*, + half*, const int, const int, + const int, const int, + const int*); template -__global__ void gemm_half_q_half_gptq_4bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_4bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - float scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - - // Column result - float block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - } - - #pragma unroll - for (int j = 0; j < 4; j++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][4]; - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); - block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); - block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); - block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); - } - - b_ptr += size_n; - a_ptr += 8; - } - - k += 32; +#pragma unroll + for (int j = 0; j < 4; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], + block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], + block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], + block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], + block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); - half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), + __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), + __float2half_rn(block_c[m][3])); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_2bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_2bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 2); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 2); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } - - #pragma unroll - for (int j = 0; j < 1; j++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - - b_ptr += size_n; - a_ptr += 16; - } - - k += 16; +#pragma unroll + for (int j = 0; j < 1; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + + b_ptr += size_n; + a_ptr += 16; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 16; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_3bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_3bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / 32 * 3; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / 32 * 3; - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } - - #pragma unroll - for (int j = 0; j < 1; j++) - { - int4 load_int4[3]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][16]; - dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); - dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); - dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); - dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - a_ptr += 32; - } - - k += 32; +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 32; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_8bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_8bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } - } - - // Zero output - if (n >= size_n) return; - - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 8); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } + // Zero output + if (n >= size_n) return; - #pragma unroll - for (int j = 0; j < 4; j++) - { - int4 load_int4[2]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][4]; - dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); - dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); - dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); - dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); - - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - a_ptr += 8; - } - k += 32; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 8); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); +#pragma unroll + for (int j = 0; j < 4; j++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 8; } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel( - bool first_block, const int m_count, const int bit) -{ - #define SELECT_KERNEL(M_COUNT) \ - if (m_count == M_COUNT) { \ - if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ - if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ - if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ - if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ - } - #if BLOCK_M_SIZE_MAX >= 1 - SELECT_KERNEL(1); - #endif - #if BLOCK_M_SIZE_MAX >= 2 - SELECT_KERNEL(2); - #endif - #if BLOCK_M_SIZE_MAX >= 3 - SELECT_KERNEL(3); - #endif - #if BLOCK_M_SIZE_MAX >= 4 - SELECT_KERNEL(4); - #endif - #if BLOCK_M_SIZE_MAX >= 5 - SELECT_KERNEL(5); - #endif - #if BLOCK_M_SIZE_MAX >= 6 - SELECT_KERNEL(6); - #endif - #if BLOCK_M_SIZE_MAX >= 7 - SELECT_KERNEL(7); - #endif - #if BLOCK_M_SIZE_MAX >= 8 - SELECT_KERNEL(8); - #endif - return NULL; + bool first_block, const int m_count, const int bit) { +#define SELECT_KERNEL(M_COUNT) \ + if (m_count == M_COUNT) { \ + if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ + if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ + if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ + if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ + } +#if BLOCK_M_SIZE_MAX >= 1 + SELECT_KERNEL(1); +#endif +#if BLOCK_M_SIZE_MAX >= 2 + SELECT_KERNEL(2); +#endif +#if BLOCK_M_SIZE_MAX >= 3 + SELECT_KERNEL(3); +#endif +#if BLOCK_M_SIZE_MAX >= 4 + SELECT_KERNEL(4); +#endif +#if BLOCK_M_SIZE_MAX >= 5 + SELECT_KERNEL(5); +#endif +#if BLOCK_M_SIZE_MAX >= 6 + SELECT_KERNEL(6); +#endif +#if BLOCK_M_SIZE_MAX >= 7 + SELECT_KERNEL(7); +#endif +#if BLOCK_M_SIZE_MAX >= 8 + SELECT_KERNEL(8); +#endif + return NULL; } +void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_q_perm, + half* c, int size_m, int size_n, int size_k, + int m_count, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = + pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(a, b_q_weight, b_gptq_qzeros, + b_gptq_scales, c, size_m, size_n, + size_k, groups, b_q_perm); +} -void gemm_half_q_half_cuda_part -( - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_q_perm, - half* c, - int size_m, - int size_n, - int size_k, - int m_count, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); - gridDim.y = DIVIDE(size_m, m_count); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); +__global__ void reconstruct_exllama_8bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - a, - b_q_weight, - b_gptq_qzeros, - b_gptq_scales, - c, - size_m, - size_n, - size_k, - groups, - b_q_perm - ); -} + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; -__global__ void reconstruct_exllama_8bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - // b offset - int qk = offset_k / (32 / 8); + // b offset + int qk = offset_k / (32 / 8); - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); - __syncthreads(); + __syncthreads(); - int k = offset_k; - int lk = 0; + int k = offset_k; + int lk = 0; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } - for (int p = 0; p < 4; p++) - { - int4 load_int4[2]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][4]; - dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); - dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); - dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); - dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); - - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + for (int p = 0; p < 4; p++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); } - k += 32; + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } } + k += 32; + } } -__global__ void reconstruct_exllama_4bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_4bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); } - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // b offset - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - - __syncthreads(); - - int k = offset_k; - int lk = 0; - - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + for (int p = 0; p < 4; p++) { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); } - - for (int p = 0; p < 4; p++) - { - half2 dq[4][4]; - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - b_ptr += size_n; - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); } - k += 32; + } } + k += 32; + } } -__global__ void reconstruct_exllama_3bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_3bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - // b offset - int qk = offset_k / 32* 3; + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - __syncthreads(); + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - int k = offset_k; - int lk = 0; + // b offset + int qk = offset_k / 32 * 3; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); - for (int p = 0; p < 1; p++) - { - int4 load_int4[3]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][16]; - dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); - dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); - dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); - dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); - - if (b_q_perm) - { - for (int j = 0; j < 16; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 16; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 1; p++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + + if (b_q_perm) { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); } - k += 32; + } } + k += 32; + } } -__global__ void reconstruct_exllama_2bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_2bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - // b offset - int qk = offset_k / (32 / 2); + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - __syncthreads(); + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - int k = offset_k; - int lk = 0; + // b offset + int qk = offset_k / (32 / 2); - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - for (int p = 0; p < 2; p++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); - - b_ptr += size_n; - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 8; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 8; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - } - k += 32; - } -} + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); -void reconstruct_exllama -( - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_q_perm, - half* out, - int height, - int width, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + __syncthreads(); - auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; - if (bit == 2) { - reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; - } else if (bit == 3) { - reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; - } else if (bit == 8) { - reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - reconstruct_exllama_kernel<<>> - ( - b_q_weight, - b_q_perm, - b_gptq_qzeros, - b_gptq_scales, - height, - width, - groups, - out - ); + for (int p = 0; p < 2; p++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } } +void reconstruct_exllama(const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_q_perm, + half* out, int height, int width, int groups, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; + if (bit == 2) { + reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; + } else if (bit == 3) { + reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; + } else if (bit == 8) { + reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_exllama_kernel<<>>( + b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, + out); +} __global__ void gemm_half_q_half_alt_4bit_kernel( - const half2* __restrict__ vec, - const uint32_t* __restrict__ mat, - half* __restrict__ mul, - const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, - const int* __restrict__ g_idx, - int batch, - int height, - int width -) -{ - int zero_width = width / 8; - int vec_height = height * 4; - const int blockwidth2 = BLOCK_KN_SIZE / 2; - int b = blockIdx.y * BLOCK_M_SIZE_MAX; - int b_end = min(BLOCK_M_SIZE_MAX, batch - b); - int h = BLOCK_KN_SIZE * blockIdx.z / 8; - int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; - int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - - __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; - if (threadIdx.x < h_end) { - for (int m = 0; m < b_end; ++m) { - blockvec[m][threadIdx.x] = - vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + - threadIdx.x]; - } + const half2* __restrict__ vec, const uint32_t* __restrict__ mat, + half* __restrict__ mul, const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; } - - __shared__ half2 deq2[256][8]; - int val = threadIdx.x / 8; - int off = threadIdx.x % 8; - for (; val < 256; val += BLOCK_KN_SIZE / 8) { - deq2[val][off] = __halves2half2( - __int2half_rn(val & 0xF), __int2half_rn(val >> 4) - ); + } + + __shared__ half2 deq2[256][8]; + int val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = + __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4)); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - + 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; } - - if (blockIdx.z == 0) - { - for (int m = 0; m < b_end; m++) - mul[(b + m) * width + w] = __int2half_rn(0); - } - __syncthreads(); - - int i = width * h + w; - int g_h = h * 8; - int k = 0; - int z_w = w / 8; - int z_mod = (w % 8) * 4; - half2 res2; - half res[BLOCK_M_SIZE_MAX] = {}; - - unsigned int tmp; - while (k < h_end) { - tmp = mat[i]; - half2 scales_tmp[4]; - half2 zeros_tmp[4]; - for (int tmp_k = 0; tmp_k < 4; tmp_k++) { - int g = g_idx[g_h + (k + tmp_k) * 2]; - int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; - half scale_f = scales[g * width + w]; - half scale_f2 = scales[g2 * width + w]; - half2 scale = __halves2half2(scale_f, scale_f2); - half2 zero = __halves2half2( - __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), - __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)) - ); - scales_tmp[tmp_k] = scale; - zeros_tmp[tmp_k] = zero; - } - for (int m = 0; m < b_end; m++) { + for (int m = 0; m < b_end; m++) { #ifndef USE_ROCM - res2 = {}; + res2 = {}; #else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); #endif - res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), + blockvec[m][k + 2], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), + blockvec[m][k + 3], res2); #ifndef USE_ROCM - res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); #else - res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); #endif - } - i += width; - k += 4; - } - for (int m = 0; m < b_end; m++) { - atomicAdd(&mul[(b + m) * width + w], res[m]); } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } } - __global__ void gemm_half_q_half_alt_8bit_kernel( - const half2* __restrict__ vec, - const uint32_t* __restrict__ mat, - half* __restrict__ mul, - const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, - const int* __restrict__ g_idx, - int batch, - int height, - int width -) -{ - int zero_width = width / 4; - int vec_height = height * 2; - const int blockwidth2 = BLOCK_KN_SIZE / 2; - int b = blockIdx.y * BLOCK_M_SIZE_MAX; - int b_end = min(BLOCK_M_SIZE_MAX, batch - b); - int h = BLOCK_KN_SIZE * blockIdx.z / 4; - int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; - int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - - __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; - if (threadIdx.x < h_end) { - for (int m = 0; m < b_end; ++m) { - blockvec[m][threadIdx.x] = - vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + - threadIdx.x]; - } + const half2* __restrict__ vec, const uint32_t* __restrict__ mat, + half* __restrict__ mul, const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 4; + int vec_height = height * 2; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 4; + int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; } - - - if (blockIdx.z == 0) - { - for (int m = 0; m < b_end; m++) - mul[(b + m) * width + w] = __int2half_rn(0); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 4; + int k = 0; + int z_w = w / 4; + int z_mod = (w % 4) * 8; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[2]; + half2 zeros_tmp[2]; + for (int tmp_k = 0; tmp_k < 2; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn( + -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; } - __syncthreads(); - - int i = width * h + w; - int g_h = h * 4; - int k = 0; - int z_w = w / 4; - int z_mod = (w % 4) * 8; - half2 res2; - half res[BLOCK_M_SIZE_MAX] = {}; - - unsigned int tmp; - while (k < h_end) { - tmp = mat[i]; - half2 scales_tmp[2]; - half2 zeros_tmp[2]; - for (int tmp_k = 0; tmp_k < 2; tmp_k++) { - int g = g_idx[g_h + (k + tmp_k) * 2]; - int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; - half scale_f = scales[g * width + w]; - half scale_f2 = scales[g2 * width + w]; - half2 scale = __halves2half2(scale_f, scale_f2); - half2 zero = __halves2half2( - __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), - __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1)) - ); - scales_tmp[tmp_k] = scale; - zeros_tmp[tmp_k] = zero; - } - for (int m = 0; m < b_end; m++) { + for (int m = 0; m < b_end; m++) { #ifndef USE_ROCM - res2 = {}; + res2 = {}; #else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); #endif - half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF)); - res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); - half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF)); - res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), + __int2half_rn((tmp >> 8) & 0xFF)); + res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), + __int2half_rn((tmp >> 24) & 0xFF)); + res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); #ifndef USE_ROCM - res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); #else - res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); #endif - } - i += width; - k += 2; - } - for (int m = 0; m < b_end; m++) { - atomicAdd(&mul[(b + m) * width + w], res[m]); } + i += width; + k += 2; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } } -void gemm_half_q_half_alt -( - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* c, - int size_m, - int size_n, - int size_k, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); - gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); - - auto kernel = gemm_half_q_half_alt_4bit_kernel; - if (bit == 8) { - kernel = gemm_half_q_half_alt_8bit_kernel; - } - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - (const half2*) a, - b_q_weight, - c, - b_gptq_scales, - b_gptq_qzeros, - b_g_idx, - size_m, - size_k / 32 * bit, - size_n - ); +void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, + half* c, int size_m, int size_n, int size_k, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + auto kernel = gemm_half_q_half_alt_4bit_kernel; + if (bit == 8) { + kernel = gemm_half_q_half_alt_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>( + (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, + size_m, size_k / 32 * bit, size_n); } -template -__global__ void reconstruct_gptq_kernel -( - const uint32_t* __restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, - const int height, - const int width, - const int group, - half* __restrict__ out -) -{ - // Start of block - - int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - int row = blockIdx.y * 32 / bit; - if (column >= width) return; - - // Views - - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, group, width); - T w_zeros_(w_zeros, group, width); - - uint32_t w_read = w[blockIdx.y * width + column]; - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int s = 0; s < 32; s += bit) - { - int group = g_idx[row + s / bit]; - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale); - *out_ptr = w_item; out_ptr += out_.width; - } +template +__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int* __restrict__ g_idx, + const int height, const int width, + const int group, + half* __restrict__ out) { + // Start of block + + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32 / bit; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + T w_zeros_(w_zeros, group, width); + + uint32_t w_read = w[blockIdx.y * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int s = 0; s < 32; s += bit) { + int group = g_idx[row + s / bit]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = + __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), + w_scale); + *out_ptr = w_item; + out_ptr += out_.width; + } } -__global__ void reconstruct_gptq_3bit_kernel -( - const uint32_t* __restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, - const int height, - const int width, - const int group, - half* __restrict__ out -) -{ - // Start of block - int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - int row = blockIdx.y * 32; - if (column >= width) return; - - // Views - - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, group, width); - MatrixView_q3_row w_zeros_(w_zeros, group, width); - - uint32_t w1 = w[(blockIdx.y * 3) * width + column]; - uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; - uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int i = 0; i < 32; i += 1) - { - int group = g_idx[row + i]; - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - int w_item; - if (i == 10) { - w_item = (w1 >> 30) | ((w2 << 2) & 0x4); - } else if (i == 21) { - w_item = (w2 >> 31) | ((w3 << 1) & 0x6); - } else if (i < 10) { - w_item = ((w1 >> (i * 3)) & 0x7); - } else if (i < 21) { - w_item = ((w2 >> (i * 3 - 32)) & 0x7); - } else { - w_item = ((w3 >> (i * 3 - 64)) & 0x7); - } - *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); - out_ptr += out_.width; +__global__ void reconstruct_gptq_3bit_kernel( + const uint32_t* __restrict__ w, const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, + const int height, const int width, const int group, + half* __restrict__ out) { + // Start of block + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q3_row w_zeros_(w_zeros, group, width); + + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; + uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; + uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int i = 0; i < 32; i += 1) { + int group = g_idx[row + i]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + int w_item; + if (i == 10) { + w_item = (w1 >> 30) | ((w2 << 2) & 0x4); + } else if (i == 21) { + w_item = (w2 >> 31) | ((w3 << 1) & 0x6); + } else if (i < 10) { + w_item = ((w1 >> (i * 3)) & 0x7); + } else if (i < 21) { + w_item = ((w2 >> (i * 3 - 32)) & 0x7); + } else { + w_item = ((w3 >> (i * 3 - 64)) & 0x7); } + *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); + out_ptr += out_.width; + } } -void reconstruct_gptq -( - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* out, - int height, - int width, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - gridDim.y = DIVIDE(height, 32 / bit); - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); - - auto kernel = reconstruct_gptq_kernel; - if (bit == 2) { - kernel = reconstruct_gptq_kernel; - } else if (bit == 8) { - kernel = reconstruct_gptq_kernel; - } else if (bit == 3) { - kernel = reconstruct_gptq_3bit_kernel; - gridDim.y = DIVIDE(height, 32); - } - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - b_q_weight, - b_gptq_scales, - b_gptq_qzeros, - b_g_idx, - height, - width, - groups, - out - ); +void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, half* out, + int height, int width, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, 32 / bit); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto kernel = reconstruct_gptq_kernel; + if (bit == 2) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 8) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 3) { + kernel = reconstruct_gptq_3bit_kernel; + gridDim.y = DIVIDE(height, 32); + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(b_q_weight, b_gptq_scales, + b_gptq_qzeros, b_g_idx, height, + width, groups, out); } - -void gemm_half_q_half_cuda -( - cublasHandle_t cublas_handle, - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* c, - half* temp_dq, - int size_m, - int size_n, - int size_k, - int groups, - bool use_exllama, - int bit -) -{ - bool use_reconstruct; +void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, + half* c, half* temp_dq, int size_m, int size_n, + int size_k, int groups, bool use_exllama, int bit) { + bool use_reconstruct; + if (use_exllama) { + use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || + (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + } else { + // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so + // we disabled them for now. + use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + } + if (use_reconstruct) { + // Reconstruct FP16 matrix, then cuBLAS if (use_exllama) { - use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); } else { - // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so we disabled them for now. - use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); } - if (use_reconstruct) { - // Reconstruct FP16 matrix, then cuBLAS - if (use_exllama) { - reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, - size_k, size_n, groups, bit); - } - else - { - reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); - } - const half alpha = __float2half(1.0f); - const half beta = __float2half(0.0f); - cublasHgemm(cublas_handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - size_n, size_m, size_k, - &alpha, temp_dq, size_n, - a, size_k, - &beta, c, size_n); + const half alpha = __float2half(1.0f); + const half beta = __float2half(0.0f); + cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k, + &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); + } else if (use_exllama) { + // Quantized matmul + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) { + gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, c, last_chunk, size_n, size_k, + BLOCK_M_SIZE_MAX, groups, bit); } - else if (use_exllama) - { - // Quantized matmul - int max_chunks = size_m / BLOCK_M_SIZE_MAX; - int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; - int last_chunk_size = size_m - last_chunk; - - if (max_chunks) - { - gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, - groups, bit); - } - if (last_chunk_size) - { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, - b_gptq_scales, b_g_idx, c + last_chunk * size_n, - last_chunk_size, size_n, size_k, last_chunk_size, - groups, bit); - } - } - else - { - gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, size_m, size_n, size_k, bit); + if (last_chunk_size) { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, + b_gptq_qzeros, b_gptq_scales, b_g_idx, + c + last_chunk * size_n, last_chunk_size, + size_n, size_k, last_chunk_size, groups, bit); } + } else { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, size_m, size_n, size_k, bit); + } } -__global__ void shuffle_4bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } +__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } } -__global__ void shuffle_8bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; } +__global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_8bit_4(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 4; + } } -__global__ void shuffle_2bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; } +__global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } } -__global__ void shuffle_3bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; } +__global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } } -__global__ void make_sequential_4bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 3; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } -__global__ void make_sequential_2bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 4; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 16; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 4; - int w2_subrow = source_row & 0x0f; - int w2_row_shift = w2_subrow << 1; - int wnew2_row_shift = i << 1; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000300000003; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 4; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 16; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 4; + int w2_subrow = source_row & 0x0f; + int w2_row_shift = w2_subrow << 1; + int wnew2_row_shift = i << 1; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000300000003; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } -__global__ void make_sequential_3bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - int w_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w_column >= w_width) return; - int w_new_row = blockIdx.y * 3; - int q_perm_idx = blockIdx.y << 5; - uint32_t dst[3] = {0, 0, 0}; - - #pragma unroll - for (int i = 0; i < 32; i++) - { - int source_row = q_perm[q_perm_idx++]; - int z_w = (source_row / 32) * 3; - int z_mod = source_row % 32; - int z_bit; - - if (z_mod != 10){ - if (z_mod != 21){ - z_bit = z_mod; - if (z_bit > 21){ - z_bit *= 3; - z_bit -= 64; - z_w += 2; - } else if (z_bit > 10){ - z_bit *= 3; - z_bit -= 32; - z_w += 1; - } else { - z_bit *= 3; - } - } else { - z_w += 1; - } - } - - uint64_t src; - if (z_mod == 10) { - src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); - } else if (z_mod == 21){ - src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); +__global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w_column >= w_width) return; + int w_new_row = blockIdx.y * 3; + int q_perm_idx = blockIdx.y << 5; + uint32_t dst[3] = {0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 32; i++) { + int source_row = q_perm[q_perm_idx++]; + int z_w = (source_row / 32) * 3; + int z_mod = source_row % 32; + int z_bit; + + if (z_mod != 10) { + if (z_mod != 21) { + z_bit = z_mod; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; } else { - src = w[z_w * w_width + w_column]; - src >>= z_bit; - src &= 0x07; + z_bit *= 3; } + } else { + z_w += 1; + } + } - z_w = 0; - if (i != 10){ - if (i != 21){ - z_bit = i; - if (z_bit > 21){ - z_bit *= 3; - z_bit -= 64; - z_w += 2; - } else if (z_bit > 10){ - z_bit *= 3; - z_bit -= 32; - z_w += 1; - } else { - z_bit *= 3; - } - } else { - z_w += 1; - } - } - if (i == 10) { - dst[z_w] |= (src & 0x03) << 30; - dst[z_w + 1] |= ((src & 0x4) >> 2); - } else if (i == 21) { - dst[z_w] |= (src & 0x01) << 31; - dst[z_w + 1] |= ((src & 0x6) >> 1); + uint64_t src; + if (z_mod == 10) { + src = (w[z_w * w_width + w_column] >> 30) | + ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); + } else if (z_mod == 21) { + src = (w[z_w * w_width + w_column] >> 31) | + ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); + } else { + src = w[z_w * w_width + w_column]; + src >>= z_bit; + src &= 0x07; + } + + z_w = 0; + if (i != 10) { + if (i != 21) { + z_bit = i; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; } else { - dst[z_w] |= (src << z_bit); + z_bit *= 3; } + } else { + z_w += 1; + } + } + if (i == 10) { + dst[z_w] |= (src & 0x03) << 30; + dst[z_w + 1] |= ((src & 0x4) >> 2); + } else if (i == 21) { + dst[z_w] |= (src & 0x01) << 31; + dst[z_w + 1] |= ((src & 0x6) >> 1); + } else { + dst[z_w] |= (src << z_bit); } - w_new[w_new_row * w_width + w_column] = dst[0]; - w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; - w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; + } + w_new[w_new_row * w_width + w_column] = dst[0]; + w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; + w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; } -__global__ void make_sequential_8bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 2; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 4; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 2; - int w2_subrow = source_row & 0x03; - int w2_row_shift = w2_subrow << 3; - int wnew2_row_shift = i << 3; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x000000ff000000ff; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 2; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 2; + int w2_subrow = source_row & 0x03; + int w2_row_shift = w2_subrow << 3; + int wnew2_row_shift = i << 3; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x000000ff000000ff; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } +void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, + int width, int bit) { + if (q_perm) { + uint32_t* new_qweight = NULL; + cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); -void shuffle_exllama_weight -( - uint32_t* q_weight, - int* q_perm, - int height, - int width, - int bit -) -{ - if (q_perm) - { - uint32_t* new_qweight = NULL; - cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); - - dim3 blockDim, gridDim; - blockDim.x = THREADS_X; - blockDim.y = 1; - gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = height / 32 * bit; - - auto kernel = make_sequential_4bit_kernel; - if (bit == 2) { - kernel = make_sequential_2bit_kernel; - } else if (bit == 3) { - kernel = make_sequential_3bit_kernel; - gridDim.y = height / 32; - } else if (bit == 8) { - kernel = make_sequential_8bit_kernel; - } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - q_weight, - new_qweight, - q_perm, - width - ); - // Replace qweights - cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - // Cleanup - cudaDeviceSynchronize(); - cudaFree(new_qweight); - } dim3 blockDim, gridDim; blockDim.x = THREADS_X; blockDim.y = 1; gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = 1; - auto shuffle_kernel = shuffle_4bit_kernel; + gridDim.y = height / 32 * bit; + + auto kernel = make_sequential_4bit_kernel; if (bit == 2) { - shuffle_kernel = shuffle_2bit_kernel; + kernel = make_sequential_2bit_kernel; } else if (bit == 3) { - shuffle_kernel = shuffle_3bit_kernel; + kernel = make_sequential_3bit_kernel; + gridDim.y = height / 32; } else if (bit == 8) { - shuffle_kernel = shuffle_8bit_kernel; + kernel = make_sequential_8bit_kernel; } const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - shuffle_kernel<<>>(q_weight, height, width); + kernel<<>>(q_weight, new_qweight, q_perm, + width); + // Replace qweights + cudaMemcpyAsync(q_weight, new_qweight, + height / 32 * bit * width * sizeof(uint32_t), + cudaMemcpyDeviceToDevice); + // Cleanup + cudaDeviceSynchronize(); + cudaFree(new_qweight); + } + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + auto shuffle_kernel = shuffle_4bit_kernel; + if (bit == 2) { + shuffle_kernel = shuffle_2bit_kernel; + } else if (bit == 3) { + shuffle_kernel = shuffle_3bit_kernel; + } else if (bit == 8) { + shuffle_kernel = shuffle_8bit_kernel; + } + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + shuffle_kernel<<>>(q_weight, height, width); } } // namespace gptq } // namespace vllm -torch::Tensor gptq_gemm -( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama, - int bit -) -{ - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); - at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); - - vllm::gptq::gemm_half_q_half_cuda - ( - at::cuda::getCurrentCUDABlasHandle(), - (const half*) a.data_ptr(), - (const uint32_t*) b_q_weight.data_ptr(), - (const uint32_t*)b_gptq_qzeros.data_ptr(), - (const half*) b_gptq_scales.data_ptr(), - b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), - (half*) c.data_ptr(), - (half*) temp_dq.data_ptr(), - c.size(0), // m - c.size(1), // n - a.size(1), // k - b_gptq_qzeros.size(0), // group number - use_exllama, - bit - ); - return c; +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); + at::Tensor temp_dq = torch::empty( + {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); + + vllm::gptq::gemm_half_q_half_cuda( + at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(), + (const uint32_t*)b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*)b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(), + (half*)c.data_ptr(), (half*)temp_dq.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + b_gptq_qzeros.size(0), // group number + use_exllama, bit); + return c; } -void gptq_shuffle -( - torch::Tensor q_weight, - torch::Tensor q_perm, - int bit -) -{ - const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); - vllm::gptq::shuffle_exllama_weight( - (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(), - q_weight.size(0) * 32 / bit, - q_weight.size(1), - bit - ); +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); + vllm::gptq::shuffle_exllama_weight( + (uint32_t*)q_weight.data_ptr(), + q_perm.device().is_meta() || q_perm.numel() == 0 + ? NULL + : (int*)q_perm.data_ptr(), + q_weight.size(0) * 32 / bit, q_weight.size(1), bit); } diff --git a/csrc/quantization/gptq/qdq_2.cuh b/csrc/quantization/gptq/qdq_2.cuh index 295872a91de37..ca0f810608d1b 100644 --- a/csrc/quantization/gptq/qdq_2.cuh +++ b/csrc/quantization/gptq/qdq_2.cuh @@ -14,71 +14,60 @@ namespace gptq { // // ffddbb99 77553311 eeccaa88 66442200 -__forceinline__ __device__ void shuffle_2bit_16 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0]; - uint32_t qb = 0; +__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; - #pragma unroll - for (int i = 0; i < 8; i++) - { - uint32_t qa0 = qa & 0x03; - uint32_t qa1 = (qa & 0x0c) >> 2; - qa >>= 4; - qb |= (qa1 << (i * 2 + 16)); - qb |= (qa0 << (i * 2)); - } - q[0] = qb; +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; } -__forceinline__ __device__ void dequant_2bit_16 -( - const uint32_t q_0, - half2 (&dq)[8], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y4_ = __float2half_rn(1.0f / 4.0f); - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y4 = __halves2half2(y4_, y4_); - const half2 y16 = __halves2half2(y16_, y16_); - const half2 y64 = __halves2half2(y64_, y64_); +__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, + half2 (&dq)[8], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); - const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); - const half2 z1 = __half2half2(z1_.as_half); - const half2 z4 = __half2half2(z4_); - const half2 z16 = __half2half2(z16_); - const half2 z64 = __half2half2(z64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z4 = __half2half2(z4_); + const half2 z16 = __half2half2(z16_); + const half2 z64 = __half2half2(z64_); - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 - half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 - half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 - qa >>= 8; - half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 - half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 - half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 - half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y4, z4); - dq[2] = __hfma2(q2.as_half2, y16, z16); - dq[3] = __hfma2(q3.as_half2, y64, z64); - dq[4] = __hadd2(q4.as_half2, z1); - dq[5] = __hfma2(q5.as_half2, y4, z4); - dq[6] = __hfma2(q6.as_half2, y16, z16); - dq[7] = __hfma2(q7.as_half2, y64, z64); + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_3.cuh b/csrc/quantization/gptq/qdq_3.cuh index 3e7ecde752ba3..0d5c2adf5dbbe 100644 --- a/csrc/quantization/gptq/qdq_3.cuh +++ b/csrc/quantization/gptq/qdq_3.cuh @@ -11,128 +11,136 @@ namespace gptq { // vjjjhhhf ffdddbbb uiiiggge eecccaaa // vtttrrrp ppnnnlll usssqqqo oommmkkk -__forceinline__ __device__ void shuffle_3bit_32 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0 * stride]; - uint32_t qb = q[1 * stride]; - uint32_t qc = q[2 * stride]; - - // qa: aa999888 77766655 54443332 22111000 - // qb: lkkkjjji iihhhggg fffeeedd dcccbbba - // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll - - uint32_t qd = qc >> 26; - qc <<= 4; - qc |= qb >> 28; - qb <<= 2; - qb |= qa >> 30; - - // qa: ..999888 77766655 54443332 22111000 - // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa - // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk - // qd: vvvuuu - - uint32_t za = 0; - uint32_t zb = 0; - uint32_t zc = 0; - - for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } - for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } - for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } - - // za: 9997775 55333111 8886664 44222000 - // zb: jjjhhhf ffdddbbb iiiggge eecccaaa - // zc: tttrrrp ppnnnlll sssqqqo oommmkkk - // qd: vvvuuu - - za |= ((qd & 0x01) >> 0) << 15; - zb |= ((qd & 0x02) >> 1) << 15; - zc |= ((qd & 0x04) >> 2) << 15; - za |= ((qd & 0x08) >> 3) << 31; - zb |= ((qd & 0x10) >> 4) << 31; - zc |= ((qd & 0x20) >> 5) << 31; - - // za: v9997775 55333111 u8886664 44222000 (u, v lsb) - // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa - // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk - - q[0 * stride] = za; - q[1 * stride] = zb; - q[2 * stride] = zc; -} - -__forceinline__ __device__ void dequant_3bit_32 -( - const uint32_t q_0, - const uint32_t q_1, - const uint32_t q_2, - half2 (&dq)[16], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y8_ = __float2half_rn(1.0f / 8.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y8 = __halves2half2(y8_, y8_); - const half2 y64 = __halves2half2(y64_, y64_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); - const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); - const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); - const half2 z8 = __halves2half2(z8_, z8_); - const half2 z64 = __halves2half2(z64_, z64_); - - uint32_t qa = q_0; - uint32_t qb = q_1; - uint32_t qc = q_2; - - half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 +__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { + uint32_t t0 = qa & 0x07; + uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; - half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 - half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 - qa >>= 9; - qa &= 0x00010001; - half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 - half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + za |= (t0 << (i * 3)); + za |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qb & 0x07; + uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; - half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 - half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 - half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 - qb >>= 8; - qb &= 0x00020002; - half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 - half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + zb |= (t0 << (i * 3)); + zb |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qc & 0x07; + uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; - half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 - half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 - half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 - qc >>= 7; - qc &= 0x00040004; - half2_uint32 q15((qa | qb | qc) | c0); - - dq[ 0] = __hadd2( q0.as_half2, z1); - dq[ 1] = __hfma2( q1.as_half2, y8, z8); - dq[ 2] = __hadd2( q2.as_half2, z1); - dq[ 3] = __hfma2( q3.as_half2, y8, z8); - dq[ 4] = __hfma2( q4.as_half2, y64, z64); - dq[ 5] = __hadd2( q5.as_half2, z1); - dq[ 6] = __hfma2( q6.as_half2, y8, z8); - dq[ 7] = __hadd2( q7.as_half2, z1); - dq[ 8] = __hfma2( q8.as_half2, y8, z8); - dq[ 9] = __hfma2( q9.as_half2, y64, z64); - dq[10] = __hadd2(q10.as_half2, z1); - dq[11] = __hfma2(q11.as_half2, y8, z8); - dq[12] = __hadd2(q12.as_half2, z1); - dq[13] = __hfma2(q13.as_half2, y8, z8); - dq[14] = __hfma2(q14.as_half2, y64, z64); - dq[15] = __hadd2(q15.as_half2, z1); + zc |= (t0 << (i * 3)); + zc |= (t1 << (i * 3 + 16)); + } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y8, z8); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y8, z8); + dq[4] = __hfma2(q4.as_half2, y64, z64); + dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hfma2(q6.as_half2, y8, z8); + dq[7] = __hadd2(q7.as_half2, z1); + dq[8] = __hfma2(q8.as_half2, y8, z8); + dq[9] = __hfma2(q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/quantization/gptq/qdq_4.cuh index 881f353f6564d..7f65d2d2819b1 100644 --- a/csrc/quantization/gptq/qdq_4.cuh +++ b/csrc/quantization/gptq/qdq_4.cuh @@ -13,133 +13,112 @@ namespace gptq { // // 77775555 33331111 66664444 22220000 -__forceinline__ __device__ void shuffle_4bit_8 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0]; - uint32_t qb = 0; - - #pragma unroll - for (int i = 0; i < 4; i++) - { - uint32_t qa0 = qa & 0x0f; - uint32_t qa1 = (qa & 0xf0) >> 4; - qa >>= 8; - qb |= (qa1 << (i * 4 + 16)); - qb |= (qa0 << (i * 4)); - } - q[0] = qb; -} - -__forceinline__ __device__ void dequant_4bit_8 -( - const uint32_t q_0, - half2 (&dq)[4], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half2 y16 = __halves2half2(y16_, y16_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - const half2 z1 = __half2half2(z1_.as_half); - const half2 z16 = __half2half2(z16_); - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 +__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y16, z16); - dq[2] = __hadd2(q2.as_half2, z1); - dq[3] = __hfma2(q3.as_half2, y16, z16); +__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, + half2 (&dq)[4], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z16 = __half2half2(z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); } -__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale -( - const uint32_t zero, - const half scale, - half2 (&z1z16)[2], - half2 (&y1y16)[2] -) -{ - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale( + const uint32_t zero, const half scale, half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - half2 scale2 = __half2half2(scale); + half2 scale2 = __half2half2(scale); - z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); - z1z16[1] = __hmul2(scale2, __half2half2(z16)); + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); - y1y16[0] = __hmul2(scale2, __half2half2(y1)); - y1y16[1] = __hmul2(scale2, __half2half2(y16)); + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); } -__forceinline__ __device__ void dequant_4bit_8_prep_zero -( - const uint32_t zero, - half2(&z1z16)[2], - half2(&y1y16)[2] -) -{ - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); +__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, + half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - z1z16[0] = __half2half2(z1.as_half); - z1z16[1] = __half2half2(z16); + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); - y1y16[0] = __half2half2(y1); - y1y16[1] = __half2half2(y16); + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); } - -__forceinline__ __device__ void dequant_4bit_8_gptq -( - const uint32_t q_0, - half2 (&dq)[4], - half2 (&z1z16)[2], - half2 (&y1y16)[2], - int stride, - bool scaled -) -{ - const uint32_t c0 = 0x64006400; - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) - qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) - - if (scaled) - { - dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) - dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) - dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); - } - else - { - dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) - dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) - dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) - } +__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, bool scaled) { + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | + c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | + c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | + c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | + c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) { + dq[0] = __hfma2(q0.as_half2, y1y16[0], + z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } else { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], + z1z16[1]); // half2( q[6] - z, q[7] - z ) + } } } // namespace gptq } // namespace vllm diff --git a/csrc/quantization/gptq/qdq_8.cuh b/csrc/quantization/gptq/qdq_8.cuh index 0c7ad7876140b..feb5d220424b0 100644 --- a/csrc/quantization/gptq/qdq_8.cuh +++ b/csrc/quantization/gptq/qdq_8.cuh @@ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2 namespace vllm { namespace gptq { -__forceinline__ __device__ void shuffle_8bit_4 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_8bit_8 -( - const uint32_t q_0, - const uint32_t q_1, - half2 (&dq)[4], - int stride, - const uint32_t zero -) -{ - half dqh[8]; - for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero); - for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); - - for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {} + +__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], int stride, + const uint32_t zero) { + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); + + for (int i = 0; i < 4; i++) + dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_util.cuh b/csrc/quantization/gptq/qdq_util.cuh index 1722a9aa6cb34..9426408fec502 100644 --- a/csrc/quantization/gptq/qdq_util.cuh +++ b/csrc/quantization/gptq/qdq_util.cuh @@ -8,51 +8,47 @@ Copied from https://github.com/turboderp/exllamav2 namespace vllm { namespace gptq { -union half2_uint32 -{ - uint32_t as_uint32; - half2 as_half2; - __device__ half2_uint32(uint32_t val) : as_uint32(val) {} - __device__ half2_uint32(half2 val) : as_half2(val) {} +union half2_uint32 { + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} }; -union half_uint16 -{ - uint16_t as_uint16; - half as_half; - __device__ half_uint16(uint16_t val) : as_uint16(val) {} - __device__ half_uint16(half val) : as_half(val) {} +union half_uint16 { + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} }; // Max_scale premultiplied by 1/256 -__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) -{ - int qs_i = qs + 1; - half qs_h = __int2half_rn(qs_i * qs_i); - qs_h = __hmul(qs_h, max_scale); - return qs_h; +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) { + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; } -__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) -{ - return __hmul(__int2half_rn(q - qzero), scale); +__forceinline__ __device__ half dq(const int q, const int qzero, + const half scale) { + return __hmul(__int2half_rn(q - qzero), scale); } -__forceinline__ __device__ half dq_ns(const int q, const int qzero) -{ - //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); - return __int2half_rn(q - qzero); +__forceinline__ __device__ half dq_ns(const int q, const int qzero) { + // return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); } -__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) -{ - return (int)((q >> shift) & mask); +__forceinline__ __device__ int exb(const uint32_t q, const int shift, + const int mask) { + return (int)((q >> shift) & mask); } -__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) -{ - return (int)(__funnelshift_rc(q0, q1, shift) & mask); +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, + const int shift, const int mask) { + return (int)(__funnelshift_rc(q0, q1, shift) & mask); } } // namespace gptq diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu new file mode 100644 index 0000000000000..c573b9041065b --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -0,0 +1,1870 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#include "gptq_marlin.cuh" +#include "gptq_marlin_dtypes.cuh" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace gptq_marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) {} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) {} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +template +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_4bit(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +template +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_8bit(int q) { + typename ScalarType::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + + constexpr int pack_factor = 32 / num_bits; + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + + } else { + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { + res = __hmul2(res, s[0]); + } + + ((scalar_t2*)sh)[idx] = res; + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (num_bits == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (num_bits == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + + #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ + prob_k, locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, + +}; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; +} + +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{0, {-1, -1, -1}}; +} + + #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +template +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, + void* g_idx, void* perm, void* a_tmp, int prob_m, + int prob_n, int prob_k, void* workspace, int num_bits, + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k, + int thread_n, int sms, int max_par) { + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; + } else { + // Auto config + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); + } + + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by having + // a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + // Main loop + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > exec_cfg.max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; + } + + // Define kernel configurations + if (false) { + } + CALL_IF(4, 32, 2, 256) + CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 8, 256) + CALL_IF(4, 8, 4, 128) + CALL_IF(4, 4, 8, 128) + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full) { + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", gptq_marlin::tile_size); + int actual_size_n = + (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Verify g_idx and perm + TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || + (g_idx.size(0) == size_k && perm.size(0) == size_k), + "Unexpected g_idx.size(0) = ", g_idx.size(0), + " and perm.size(0) = ", perm.size(0), + ", where size_k = ", size_k); + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + bool has_act_order = g_idx.size(0) != 0; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), + " is not size_n = ", size_n); + num_groups = b_scales.size(0); + + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + // Verify workspace size + TORCH_CHECK( + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); + int min_workspace_size = + (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, + group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, gptq_marlin::max_par); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, + is_k_full, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } else { + TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh new file mode 100644 index 0000000000000..ba5368ea8835f --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -0,0 +1,76 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace gptq_marlin { + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = + 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async +#else + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace gptq_marlin diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh new file mode 100644 index 0000000000000..ca1b7099d6ec7 --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh @@ -0,0 +1,77 @@ + +#ifndef _data_types_cuh +#define _data_types_cuh +#include "gptq_marlin.cuh" +#include +#include + +namespace gptq_marlin { + +template +class ScalarType {}; + +template <> +class ScalarType { + public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType { + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } +#endif +}; + +} // namespace gptq_marlin + +#endif diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu new file mode 100644 index 0000000000000..4adc158eb14ea --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -0,0 +1,350 @@ +#include "gptq_marlin.cuh" + +namespace gptq_marlin { + +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) {} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +template +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + int start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4* sh_perm_ptr = sh; + int4* sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; + } + + constexpr int tile_ints = tile_k_size / pack_factor; + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + uint32_t const* sh_perm_int_ptr = + reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&( + b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + + } else { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + int warp_id = threadIdx.x / 32; + int th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[8]; + + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + + } else { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; + + #pragma unroll + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } + + #pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; + #pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; + #pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; + #pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace gptq_marlin + + #define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin::marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin::marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits) { + // Verify compatibility with marlin tile of 16x64 + TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); + TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / num_bits; + + // Verify B + TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", pack_factor = ", pack_factor); + TORCH_CHECK(b_q_weight.size(1) == size_n, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not size_n = ", size_n); + + // Verify device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + torch::Tensor out = + torch::empty({size_k / gptq_marlin::tile_size, + size_n * gptq_marlin::tile_size / pack_factor}, + options); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const* b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + int dev = b_q_weight.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (false) { + } + CALL_IF(4, false) + CALL_IF(4, true) + CALL_IF(8, false) + CALL_IF(8, true) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, + ", has_perm = ", has_perm); + } + + return out; +} + +#endif diff --git a/csrc/quantization/marlin/LICENSE b/csrc/quantization/marlin/dense/LICENSE similarity index 100% rename from csrc/quantization/marlin/LICENSE rename to csrc/quantization/marlin/dense/LICENSE diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu similarity index 73% rename from csrc/quantization/marlin/marlin_cuda_kernel.cu rename to csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index cf1b0afdec8b4..03d66cecedf1f 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -25,7 +25,10 @@ #include -template inline std::string str(T x) { return std::to_string(x); } +template +inline std::string str(T x) { + return std::to_string(x); +} namespace marlin { @@ -38,9 +41,10 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } // corresponding index accesses must be compile-time constants, which is why we // extensively use `#pragma unroll` throughout the kernel code to guarantee // this. -template struct Vec { +template +struct Vec { T elems[n]; - __device__ T &operator[](int i) { return elems[i]; } + __device__ T& operator[](int i) { return elems[i]; } }; using I4 = Vec; @@ -51,34 +55,30 @@ using I4 = Vec; using FragA = Vec; using FragB = Vec; using FragC = Vec; -using FragS = Vec; // quantization scales +using FragS = Vec; // quantization scales // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } -// Asynchronous global->shared copy with a cache hint indicating that the values -// may be evicted immediately; used for quantized weights B, which are only -// accessed precisely once and should thus not pollute the L2 cache which we -// need for inputs A and outputs C. -__device__ inline void cp_async4_stream(void *smem_ptr, const void *glob_ptr) { +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" - " .reg .b64 p;\n" - " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" - " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES)); } @@ -89,28 +89,30 @@ __device__ inline void cp_async_fence() { } // Wait until at most `n` async copy stages are still pending. -template __device__ inline void cp_async_wait() { +template +__device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, - FragC &frag_c) { - const uint32_t *a = reinterpret_cast(&a_frag); - const uint32_t *b = reinterpret_cast(&frag_b); - float *c = reinterpret_cast(&frag_c); - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) @@ -120,7 +122,8 @@ __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template __device__ inline int lop3(int a, int b, int c) { +template +__device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -145,24 +148,24 @@ __device__ inline FragB dequant(int q) { const int MUL = 0x2c002c00; const int ADD = 0xd480d480; FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. -__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int *lock, int count) { +__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -177,7 +180,7 @@ __device__ inline void barrier_acquire(int *lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int *lock, bool reset = false) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -194,26 +197,27 @@ __device__ inline void barrier_release(int *lock, bool reset = false) { } } -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -248,11 +252,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice + int slice_iters; // number of threadblock tiles in the current slice int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers @@ -268,27 +272,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; + if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { @@ -300,29 +299,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; init_slice(); - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory // We typically use `constexpr` to indicate that this value is a compile-time // constant constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory + 8; // delta between subsequent A tiles in global memory int a_gl_rd_delta_i = a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile constexpr int a_sh_wr_delta = - a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + a_sh_stride * + (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads + (thread_n_blocks / 4)); // between shared memory tile reads constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile + a_sh_stride * 16; // within a shared memory tile constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile + a_sh_wr_delta); // number of shared write iterations for a tile int b_gl_stride = 16 * prob_n / 32; constexpr int b_sh_stride = 32 * thread_n_blocks / 4; @@ -375,7 +375,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; @@ -394,13 +394,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -410,16 +410,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_s = sh_b + (stages * b_sh_stage); + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2]; @@ -428,34 +428,33 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Zero accumulators. auto zero_accums = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); B_ptr[i] += b_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) - cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } @@ -482,37 +481,35 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // theoretically better attempts have lead to bad instruction ordering by // the compiler and correspondingly a noticeable drop in performance. if (group_blocks != -1) { - int4 *sh_s_stage = + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll for (int j = 0; j < 4; j++) { int b_quant = frag_b_quant[k % 2][j]; int b_quant_shift = b_quant >> 8; FragB frag_b0 = dequant(b_quant); // If there are no groups, we can just scale the final output once and can // avoid doing so for each weight. - if (group_blocks != -1) - scale(frag_b0, frag_s[k % 2][j], 0); + if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) - scale(frag_b1, frag_s[k % 2][j], 1); -#pragma unroll + if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); @@ -537,38 +534,38 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. -#pragma unroll + #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll + #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -578,9 +575,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partitioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out @@ -599,39 +596,39 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int row = (threadIdx.x % 32) / 4; if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. -#pragma unroll + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + row < prob_m); + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half *>(&c_red)[j]); + __half2float(reinterpret_cast<__half*>(&c_red)[j]); } } if (!last) { int4 c; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half *>(&c)[j] = - __float2half(reinterpret_cast( + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -665,17 +662,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS &s) { + auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); if (group_blocks == - -1) // for per-column quantization we finally apply the scale here + -1) // for per-column quantization we finally apply the scale here res = __hmul2(res, s[0]); - ((half2 *)sh)[idx] = res; + ((half2*)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], @@ -692,7 +689,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -706,9 +703,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Start global fetch and register load pipelines. auto start_pipes = [&]() { -#pragma unroll - for (int i = 0; i < stages - 1; i++) - fetch_to_shared(i, i, i < slice_iters); + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); zero_accums(); wait_for_stage(); fetch_to_registers(0, 0); @@ -718,12 +714,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Main loop. while (slice_iters) { -// We unroll over both the global fetch and the register load pipeline to ensure -// all shared memory accesses are static. Note that both pipelines have even -// length meaning that the next iteration will always start at index 0. -#pragma unroll + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll for (int pipe = 0; pipe < stages;) { -#pragma unroll + #pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); if (k == b_sh_wr_iters - 2) { @@ -735,8 +731,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk matmul(k); } slice_iters--; - if (slice_iters == 0) - break; + if (slice_iters == 0) break; } a_gl_rd += a_gl_rd_delta_o * stages; @@ -749,8 +744,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // For per-column scales, we only fetch them here in the final step before // write-out if (group_blocks == -1 && last) { - if (s_sh_wr_pred) - cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } thread_block_reduce(); @@ -758,17 +752,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; @@ -777,13 +771,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); @@ -794,26 +787,27 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk #else -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Marlin is not implemented yet for SM < 8.0 assert(false); @@ -826,10 +820,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; @@ -838,7 +832,7 @@ static constexpr int tile_size = 16; static constexpr int max_par = 16; static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit + 8; // We have 8 4-bit vals inside a 32 bit #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ GROUP_BLOCKS, NUM_THREADS) \ @@ -865,23 +859,23 @@ thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N }; thread_config_t large_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X }; -bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || @@ -914,7 +908,6 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, } thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { for (auto th_config : small_batch_thread_configs) { if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { @@ -933,20 +926,20 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { return thread_config_t{-1, -1, -1}; } -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) -void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, - int prob_n, int prob_k, void *workspace, int groupsize = -1, +void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, + int prob_n, int prob_k, void* workspace, int groupsize = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, int sms = -1, int max_par = 16) { int tot_m = prob_m; @@ -1003,12 +996,12 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, " is not divisible by group_blocks = ", group_blocks); } - const int4 *A_ptr = (const int4 *)A; - const int4 *B_ptr = (const int4 *)B; - int4 *C_ptr = (int4 *)C; - const int4 *s_ptr = (const int4 *)s; + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; - int *locks = (int *)workspace; + int* locks = (int*)workspace; for (int i = 0; i < tot_m_blocks; i += 4) { int thread_m_blocks = tot_m_blocks - i; @@ -1018,8 +1011,7 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) - par = max_par; + if (par > max_par) par = max_par; prob_m = 64 * par; i += 4 * (par - 1); thread_m_blocks = 4; @@ -1048,12 +1040,11 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, } } -} // namespace marlin +} // namespace marlin -torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &workspace, +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M TORCH_CHECK(size_m == a.size(0), "Shape mismatch: a.size(0) = " + str(a.size(0)) + @@ -1081,9 +1072,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); // Verify A device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); diff --git a/csrc/quantization/marlin/sparse/LICENSE b/csrc/quantization/marlin/sparse/LICENSE new file mode 100644 index 0000000000000..ca75fb15e660a --- /dev/null +++ b/csrc/quantization/marlin/sparse/LICENSE @@ -0,0 +1,203 @@ +Contains code from https://github.com/IST-DASLab/Sparse-Marlin/ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/csrc/quantization/marlin/sparse/common/base.h b/csrc/quantization/marlin/sparse/common/base.h new file mode 100644 index 0000000000000..16018d331bec2 --- /dev/null +++ b/csrc/quantization/marlin/sparse/common/base.h @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace marlin_24 { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +template +struct ShapeBase { + static constexpr int M = M_, N = N_, K = K_; +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragM = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mem.h b/csrc/quantization/marlin/sparse/common/mem.h new file mode 100644 index 0000000000000..83e3578d2f511 --- /dev/null +++ b/csrc/quantization/marlin/sparse/common/mem.h @@ -0,0 +1,136 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "base.h" + +namespace marlin_24 { +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred_zfill(void* smem_ptr, + const void* glob_ptr, + bool pred = true, + const bool zfill = false) { + const int BYTES = 16; + int src_in_bytes = (zfill ? 0 : BYTES); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); +} + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_m); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h new file mode 100644 index 0000000000000..45ab67a78a1de --- /dev/null +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -0,0 +1,180 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "base.h" + +namespace marlin_24 { + +// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, + const FragA& frag_b, FragC& frag_c, FragM& frag_m, + const int psel) { + const uint32_t* a0 = reinterpret_cast(&a_frag0); + const uint32_t* a1 = reinterpret_cast(&a_frag1); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* e = reinterpret_cast(&frag_m); + float* c = reinterpret_cast(&frag_c); + if (psel == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]), + "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), + "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), + "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), + "r"(e[0])); + } else { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]), + "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), + "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), + "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), + "r"(e[0])); + } +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, + float c3) { + uint2 r; + asm("{\n\t" + ".reg .f16 a, b, c, d; \n\t" + "cvt.rn.f16.f32 a, %2; \n\t" + "cvt.rn.f16.f32 b, %3; \n\t" + "cvt.rn.f16.f32 c, %4; \n\t" + "cvt.rn.f16.f32 d, %5; \n\t" + "mov.b32 %0, {a, b}; \n\t" + "mov.b32 %1, {c, d}; \n\t" + "}" + : "=r"(r.x), "=r"(r.y) + : "f"(c0), "f"(c1), "f"(c2), "f"(c3)); + return r; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_4bit(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, + FragS& s0, float* c4, float* c5, float* c6, + float* c7, FragS& s1) { + *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); + *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); + *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); + *c3 = __fmul_rn(*c3, __half2float(s0[1].y)); + + *c4 = __fmul_rn(*c4, __half2float(s1[0].x)); + *c5 = __fmul_rn(*c5, __half2float(s1[0].y)); + *c6 = __fmul_rn(*c6, __half2float(s1[1].x)); + *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); +} + +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu new file mode 100644 index 0000000000000..686dd7851e6af --- /dev/null +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -0,0 +1,1125 @@ +/* + * Notice: This file was modified by Neuralmagic inc to include 8-bit support + * + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include + +#include + +#include "common/base.h" + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +#else + + #include "common/mem.h" + #include "common/mma.h" + +#endif + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace marlin_24 { + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int THREADS = 256; +static constexpr int STAGES = 4; + +static constexpr int min_thread_n = 128; + +static constexpr int tile_size = 16; +static constexpr int max_par = 64; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin_24( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) {} + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin_24( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + // number of thread_k_blocks in k-dim + int k_tiles = prob_k / 32 / thread_k_blocks; + // number of thread_n_blocks in n-dim + int n_tiles = prob_n / 16 / thread_n_blocks; + // iters needed to cover all slices + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + // number of threadblock tiles in the current slice + int slice_iters; + // total number of active threadblocks in the current slice + int slice_count = 0; + // index of threadblock in current slice; numbered bottom to top + int slice_idx; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 32 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads //RLC: 2 * #warps k-dim + constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + constexpr int pack_factor = 32 / num_bits; + + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + constexpr int m_sh_stride = + (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; + int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); + constexpr int m_sh_wr_delta = threads / 2; + constexpr int m_sh_rd_delta = threads / 2; + constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; + constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + + (threadIdx.x % (m_sh_stride)); + m_gl_rd += (m_sh_stride)*slice_col; + m_gl_rd += m_gl_rd_delta_o * slice_row; + int m_sh_wr = threadIdx.x; + int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; + + int s_gl_rd; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + if (group_blocks != -1) { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } else { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[0][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[1][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); + } + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; + const int4* meta_ptr[m_sh_iters]; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + int4* sh_m = sh_s + (stages * s_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks][2]; + I4 frag_b_quant[2][b_thread_vecs]; + FragM frag_m[2][2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + int4* sh_meta_stage = sh_m + m_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) { + if (m_sh_wr_pred) + cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); + meta_ptr[i] += m_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticeable drop in performance. + if (group_blocks != -1) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + ldsm4(frag_a[k % 2][i][0], + &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); + ldsm4(frag_a[k % 2][i][1], + &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); + } + + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + + // Load meta with ldsm4 + int4* sh_m_stage = sh_m + m_sh_stage * pipe; + ldsm4_m(frag_m[k % 2][0], + &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + + } else { + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], + frag_m[k % 2][j / 2], j % 2); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; + int c_gl_wr_delta_i = + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int col = 2 * ((threadIdx.x % 32) % 4); + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j2 = 0; j2 < 2; j2++) { + #pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2] += + __half2float( + reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); + } + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j2 = 0; j2 < 2; j2++) { + #pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2]); + } + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + + int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + + constexpr int c_sh_rd_delta = + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + + (threadIdx.x % (2 * 2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, + float c4, float c5, float c6, float c7, FragS& s1) { + uint2 res[2]; + res[0] = to_half4(c0, c1, c2, c3); + res[1] = to_half4(c4, c5, c6, c7); + half2* tmp = (half2*)&res; + // for per-column quantization we finally apply the scale here + if constexpr (group_blocks == -1 && num_bits == 4) { + tmp[0] = __hmul2(tmp[0], s0[0]); + tmp[1] = __hmul2(tmp[1], s0[1]); + tmp[2] = __hmul2(tmp[2], s1[0]); + tmp[3] = __hmul2(tmp[3], s1[1]); + } + ((int4*)sh)[idx] = *((int4*)&res[0]); + }; + + // RLC: only warp 0 and 1 baseline example + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + int wr = c_sh_wr; + write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], + frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], + frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], + frag_s[0][2]); + write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1], + frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0], + frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3], + frag_c[i][3][0][3], frag_s[0][2]); + write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0], + frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0], + frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2], + frag_c[i][3][1][2], frag_s[0][2]); + write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1], + frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1], + frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3], + frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]); + + c_sh_wr += 8 * c_sh_stride_2; + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + matmul(pipe); + wait_for_stage(); + + fetch_to_registers(pipe + 1, (pipe + 1) % stages); + + pipe++; + slice_iters--; + if (slice_iters == 0) break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (group_blocks == -1) { + if constexpr (num_bits == 8) { + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + } + } + thread_block_reduce(); + + if constexpr (group_blocks == -1) { + if constexpr (num_bits == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); + } + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], + &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], + &frag_c[i][0][0][2], &frag_c[i][1][0][2], + &frag_c[i][2][0][2], &frag_c[i][3][0][2], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1], + &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0], + &frag_c[i][0][0][3], &frag_c[i][1][0][3], + &frag_c[i][2][0][3], &frag_c[i][3][0][3], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0], + &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0], + &frag_c[i][0][1][2], &frag_c[i][1][1][2], + &frag_c[i][2][1][2], &frag_c[i][3][1][2], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1], + &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0], + &frag_c[i][0][1][3], &frag_c[i][1][1][3], + &frag_c[i][2][1][3], &frag_c[i][3][1][3], + frag_s[0][2]); + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + +#endif + +#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + cudaFuncSetAttribute( \ + Marlin_24, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin_24 \ + <<>>(A_ptr, B_ptr, meta_ptr, \ + C_ptr, s_ptr, prob_n, \ + prob_m, prob_k, locks); \ + } + +void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, + void* s, int prob_m, int prob_n, int prob_k, + void* workspace, int num_bits, int groupsize = -1, + int dev = 0, cudaStream_t stream = 0, int thread_k = -1, + int thread_m = -1, int sms = -1, int max_par = 16) { + int tot_n = prob_n; + int tot_n_blocks = ceildiv(tot_n, 16); + int pad = 16 * tot_n_blocks - tot_n; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + TORCH_CHECK(sms > 0); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (thread_k == -1 || thread_m == -1) { + if (prob_n <= 16) { + // For small batchizes, better partitioningif is slightly more important + // than better compute utilization + thread_k = 128; + thread_m = 128; + } else if (prob_n <= 256) { + thread_k = 64; + thread_m = 256; + } else { + thread_k = 32; + thread_m = 512; + } + } + + int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction + int thread_m_blocks = thread_m / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m, + " is not divisible by thread_m = ", thread_m); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + if (group_blocks != -1) { + TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2, + " is not divisible by group_blocks = ", group_blocks); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + const int4* meta_ptr = (const int4*)meta; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + + constexpr int max_m_blocks = 4; + + int* locks = (int*)workspace; + for (int i = 0; i < tot_n_blocks; i += max_m_blocks) { + int thread_n_blocks = tot_n_blocks - i; + prob_n = tot_n - 16 * i; + int par = 1; + if (thread_n_blocks > max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16); + if (par > max_par) par = max_par; + prob_n = (max_m_blocks * 16) * par; + i += max_m_blocks * (par - 1); + thread_n_blocks = max_m_blocks; + } + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. + + // the false is start of the CALL_IF macros + if (false) { + } // BMxBNxBK, group + // 4-bit + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 16, 2, 2, 4) + CALL_IF_2_4(4, 16, 3, 2, -1) + CALL_IF_2_4(4, 16, 3, 2, 4) + CALL_IF_2_4(4, 16, 4, 2, -1) + CALL_IF_2_4(4, 16, 4, 2, 4) + + CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 32, 2, 1, 4) + CALL_IF_2_4(4, 32, 3, 1, -1) + CALL_IF_2_4(4, 32, 3, 1, 4) + CALL_IF_2_4(4, 32, 4, 1, -1) + CALL_IF_2_4(4, 32, 4, 1, 4) + + // 8-bit + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 16, 2, 2, 4) + CALL_IF_2_4(8, 16, 3, 2, -1) + CALL_IF_2_4(8, 16, 3, 2, 4) + CALL_IF_2_4(8, 16, 4, 2, -1) + CALL_IF_2_4(8, 16, 4, 2, 4) + + CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 32, 2, 1, 4) + CALL_IF_2_4(8, 32, 3, 1, -1) + CALL_IF_2_4(8, 32, 3, 1, 4) + CALL_IF_2_4(8, 32, 4, 1, -1) + CALL_IF_2_4(8, 32, 4, 1, 4) + else { + throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + + ", " + str(prob_k) + ", " + str(prob_n) + "]" + + ", groupsize = " + str(groupsize) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par; + } +} + +} // namespace marlin_24 + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify M + TORCH_CHECK(size_m == a.size(0), + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + + // Verify K + TORCH_CHECK(size_k == a.size(1), + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + TORCH_CHECK(size_k % marlin_24::tile_size == 0, + "size_k = " + str(size_k) + " is not divisible by tile_size = " + + str(marlin_24::tile_size)); + TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + + ", tile_size = " + str(marlin_24::tile_size)); + + // Verify N + TORCH_CHECK(b_scales.size(1) == size_n, + "b_scales.size(1) = " + str(b_scales.size(1)) + + ", size_n = " + str(size_n)); + TORCH_CHECK( + b_q_weight.size(1) % marlin_24::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(marlin_24::tile_size)); + + int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); + + // Verify meta + TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, + "b_meta.size(0) = ", b_meta.size(0), + " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2); + TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1), + " is not size_n * 2 = ", size_n * 2); + + // Verify A device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + // Verify B device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + // Verify b_meta device and strides + TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU"); + TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous"); + + // Verify scales device and strides + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc C matrix + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + int thread_k = -1; + int thread_m = -1; + int sms = -1; + int max_par = marlin_24::max_par; + + int groupsize = -1; + if (b_scales.size(0) > 1) { + TORCH_CHECK(size_k % b_scales.size(0) == 0, + "size_k = " + str(size_k) + + ", is not divisible by b_scales.size(0) = " + + str(b_scales.size(0))); + groupsize = size_k / b_scales.size(0); + groupsize /= 2; // Because of 24 + } + + // Verify groupsize + TORCH_CHECK(groupsize == -1 || groupsize == 64, + "Unexpected groupsize = " + str(groupsize)); + + // Verify workspace size + TORCH_CHECK(size_n % marlin_24::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + + str(marlin_24::min_thread_n)); + int min_workspace_size = + (size_n / marlin_24::min_thread_n) * marlin_24::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + int dev = a.get_device(); + marlin_24::marlin_cuda_2_4( + a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), + num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_m, sms, max_par); + + return c; +} diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 09964903622b4..1b339fa4b392b 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) { // 4-bit matvec kernel (LUT-based) __global__ void NUQ4MatMulKernel( #ifndef USE_ROCM - const half2* __restrict__ vec, + const half2* __restrict__ vec, #else - const __half2* __restrict__ vec, + const __half2* __restrict__ vec, #endif - const int* __restrict__ mat, + const int* __restrict__ mat, #ifndef USE_ROCM - half2* __restrict__ mul, + half2* __restrict__ mul, #else - float2* __restrict__ mul, + float2* __restrict__ mul, #endif - const __half* __restrict__ lookup_table, - int height, - int width, - int batch, - int vec_height -) { + const __half* __restrict__ lookup_table, int height, int width, int batch, + int vec_height) { const int blockwidth2 = BLOCKWIDTH / 2; int row = BLOCKHEIGHT4 * blockIdx.x; - int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; + int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; #ifndef USE_ROCM __shared__ half2 blockvec[blockwidth2]; @@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel( unsigned int tmp1; unsigned int lut_index1, lut_index2; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b) { i = width * row + col; res = __int2half_rd(0); k = 0; __syncthreads(); if (threadIdx.x < blockwidth2) - blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x]; + blockvec[threadIdx.x] = + vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + + threadIdx.x]; __syncthreads(); while (k < blockwidth2) { @@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel( #ifndef USE_ROCM res = __hadd(__hadd(res2.x, res2.y), res); #else - res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); + res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), + res); #endif i += width; @@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel( } } -} // namespace squeezellm -} // namespace vllm +} // namespace squeezellm +} // namespace vllm // 4-bit matvec kernel (LUT-based) -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table -) { +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table) { int height = mat.size(0); int width = mat.size(1); int batch = vec.size(0); int vec_height = vec.size(1); - dim3 blocks( - (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, - (width + BLOCKWIDTH - 1) / BLOCKWIDTH - ); + dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH); dim3 threads(BLOCKWIDTH); const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); vllm::squeezellm::NUQ4MatMulKernel<<>>( #ifndef USE_ROCM - (half2*) vec.data(), + (half2*)vec.data(), #else - (__half2*) vec.data_ptr(), + (__half2*)vec.data_ptr(), #endif - mat.data_ptr(), + mat.data_ptr(), #ifndef USE_ROCM - (half2*) mul.data(), - (__half*) lookup_table.data(), + (half2*)mul.data(), (__half*)lookup_table.data(), #else - (float2*) mul.data_ptr(), - (__half*) lookup_table.data_ptr(), + (float2*)mul.data_ptr(), + (__half*)lookup_table.data_ptr(), #endif - height, width, batch, vec_height - ); + height, width, batch, vec_height); } #undef BLOCKWIDTH diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index 392b74d1bb73a..9af4aae516151 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -20,12 +21,12 @@ #include "cuda_compat.h" namespace vllm { -template +template __inline__ __device__ T warpReduceSum(T val) { static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, "numLanes is not a positive power of 2!"); static_assert(numLanes <= WARP_SIZE); - #pragma unroll +#pragma unroll for (int mask = numLanes >> 1; mask > 0; mask >>= 1) val += VLLM_SHFL_XOR_SYNC(val, mask); return val; @@ -38,22 +39,23 @@ static constexpr int _nextPow2(unsigned int num) { } /* Calculate the sum of all elements in a block */ -template +template __inline__ __device__ T blockReduceSum(T val) { static_assert(maxBlockSize <= 1024); if constexpr (maxBlockSize > WARP_SIZE) { val = warpReduceSum(val); - // Calculates max number of lanes that need to participate in the last warpReduce + // Calculates max number of lanes that need to participate in the last + // warpReduce constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; static __shared__ T shared[maxActiveLanes]; int lane = threadIdx.x % WARP_SIZE; int wid = threadIdx.x / WARP_SIZE; - if (lane == 0) - shared[wid] = val; + if (lane == 0) shared[wid] = val; __syncthreads(); - val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f); + val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] + : (T)(0.0f); val = warpReduceSum(val); } else { // A single warpReduce is equal to blockReduce @@ -62,4 +64,4 @@ __inline__ __device__ T blockReduceSum(T val) { return val; } -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 0e76763a87b7c..ed569816200ee 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -10,3 +10,4 @@ pydantic torch py-cpuinfo transformers +openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args diff --git a/docs/source/assets/dev/dockerfile-stages-dependency.png b/docs/source/assets/dev/dockerfile-stages-dependency.png new file mode 100644 index 0000000000000..b016531f1e0a0 Binary files /dev/null and b/docs/source/assets/dev/dockerfile-stages-dependency.png differ diff --git a/docs/source/community/meetups.rst b/docs/source/community/meetups.rst new file mode 100644 index 0000000000000..f371194781de3 --- /dev/null +++ b/docs/source/community/meetups.rst @@ -0,0 +1,12 @@ +.. _meetups: + +vLLM Meetups +============ + +We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: + +- `The third vLLM meetup `__, with Roblox, April 2nd 2024. `[Slides] `__ +- `The second vLLM meetup `__, with IBM Research, January 31st 2024. `[Slides] `__ `[Video (vLLM Update)] `__ `[Video (IBM Research & torch.compile)] `__ +- `The first vLLM meetup `__, with a16z, October 5th 2023. `[Slides] `__ + +We are always looking for speakers and sponsors at San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at `vllm-questions@lists.berkeley.edu `__. diff --git a/docs/source/community/sponsors.md b/docs/source/community/sponsors.md new file mode 100644 index 0000000000000..d167b66267a4d --- /dev/null +++ b/docs/source/community/sponsors.md @@ -0,0 +1,25 @@ +# Sponsors + +vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! + + + + +- a16z +- AMD +- Anyscale +- AWS +- Crusoe Cloud +- Databricks +- DeepInfra +- Dropbox +- Lambda Lab +- NVIDIA +- Replicate +- Roblox +- RunPod +- Trainy +- UC Berkeley +- UC San Diego + +We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 44cda7c99cdd5..cfebc2ff9bb33 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,11 +11,14 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. import logging +import os import sys +from typing import List from sphinx.ext import autodoc logger = logging.getLogger(__name__) +sys.path.append(os.path.abspath("../..")) # -- Project information ----------------------------------------------------- @@ -45,7 +48,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] +exclude_patterns: List[str] = ["**/*.template.rst"] # Exclude the prompt "$" when copying code copybutton_prompt_text = r"\$ " @@ -70,7 +73,14 @@ # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] -# Mock out external dependencies here. + +# Generate additional rst documentation here. +def setup(app): + from docs.source.generate_examples import generate_examples + generate_examples() + + +# Mock out external dependencies here, otherwise the autodoc pages may be blank. autodoc_mock_imports = [ "cpuinfo", "torch", @@ -82,14 +92,16 @@ "vllm._C", "numpy", "tqdm", + "tensorizer", ] for mock_target in autodoc_mock_imports: if mock_target in sys.modules: logger.info( - f"Potentially problematic mock target ({mock_target}) found; " + "Potentially problematic mock target (%s) found; " "autodoc_mock_imports cannot mock modules that have already " - "been loaded into sys.modules when the sphinx build starts.") + "been loaded into sys.modules when the sphinx build starts.", + mock_target) class MockedClassDocumenter(autodoc.ClassDocumenter): @@ -103,4 +115,15 @@ def add_line(self, line: str, source: str, *lineno: int) -> None: autodoc.ClassDocumenter = MockedClassDocumenter +intersphinx_mapping = { + 'python': ('https://docs.python.org/3', None), + 'typing_extensions': + ('https://typing-extensions.readthedocs.io/en/latest', None), + 'numpy': ('https://numpy.org/doc/stable', None), + 'torch': ('https://pytorch.org/docs/stable', None), + 'psutil': ('https://psutil.readthedocs.io/en/stable', None), +} + +autodoc_preserve_defaults = True + navigation_with_keys = False diff --git a/docs/source/dev/dockerfile/dockerfile.rst b/docs/source/dev/dockerfile/dockerfile.rst new file mode 100644 index 0000000000000..a07463392dbe8 --- /dev/null +++ b/docs/source/dev/dockerfile/dockerfile.rst @@ -0,0 +1,50 @@ +Dockerfile +==================== + +See `here `_ for the main Dockerfile to construct +the image for running an OpenAI compatible server with vLLM. + +- Below is a visual representation of the multi-stage Dockerfile. The build graph contains the following nodes: + + - All build stages + - The default build target (highlighted in grey) + - External images (with dashed borders) + + The edges of the build graph represent: + + - FROM ... dependencies (with a solid line and a full arrow head) + - COPY --from=... dependencies (with a dashed line and an empty arrow head) + - RUN --mount=(.*)from=... dependencies (with a dotted line and an empty diamond arrow head) + + .. figure:: ../../assets/dev/dockerfile-stages-dependency.png + :alt: query + :width: 100% + :align: center + + Made using: https://github.com/patrickhoefler/dockerfilegraph + + Commands to regenerate the build graph (make sure to run it **from the `root` directory of the vLLM repository** where the dockerfile is present): + + .. code:: bash + + dockerfilegraph -o png --legend --dpi 200 --max-label-length 50 --filename Dockerfile + + or in case you want to run it directly with the docker image: + + .. code:: bash + + docker run \ + --rm \ + --user "$(id -u):$(id -g)" \ + --workdir /workspace \ + --volume "$(pwd)":/workspace \ + ghcr.io/patrickhoefler/dockerfilegraph:alpine \ + --output png \ + --dpi 200 \ + --max-label-length 50 \ + --filename Dockerfile \ + --legend + + (To run it for a different file, you can pass in a different argument to the flag `--filename`.) + + \ No newline at end of file diff --git a/docs/source/dev/engine/async_llm_engine.rst b/docs/source/dev/engine/async_llm_engine.rst index 47db1e0a401b1..93fc310cb543b 100644 --- a/docs/source/dev/engine/async_llm_engine.rst +++ b/docs/source/dev/engine/async_llm_engine.rst @@ -1,7 +1,6 @@ - AsyncLLMEngine ================================= -.. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine - :members: generate, abort +.. autoclass:: vllm.AsyncLLMEngine + :members: :show-inheritance: diff --git a/docs/source/dev/engine/llm_engine.rst b/docs/source/dev/engine/llm_engine.rst index 1de6d7adc87c6..0b8c1e219d7c9 100644 --- a/docs/source/dev/engine/llm_engine.rst +++ b/docs/source/dev/engine/llm_engine.rst @@ -1,6 +1,6 @@ LLMEngine ================================= -.. autoclass:: vllm.engine.llm_engine.LLMEngine - :members: add_request, abort_request, step - :show-inheritance: \ No newline at end of file +.. autoclass:: vllm.LLMEngine + :members: + :show-inheritance: diff --git a/docs/source/dev/offline_inference/llm.rst b/docs/source/dev/offline_inference/llm.rst new file mode 100644 index 0000000000000..83ba1b6987c6d --- /dev/null +++ b/docs/source/dev/offline_inference/llm.rst @@ -0,0 +1,6 @@ +LLM Class +========= + +.. autoclass:: vllm.LLM + :members: + :show-inheritance: diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst new file mode 100644 index 0000000000000..31c3d16a3c8eb --- /dev/null +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -0,0 +1,14 @@ +LLM Inputs +========== + +.. autodata:: vllm.inputs.PromptStrictInputs + +.. autoclass:: vllm.inputs.TextPrompt + :show-inheritance: + :members: + :member-order: bysource + +.. autoclass:: vllm.inputs.TokensPrompt + :show-inheritance: + :members: + :member-order: bysource diff --git a/docs/source/dev/offline_inference/offline_index.rst b/docs/source/dev/offline_inference/offline_index.rst new file mode 100644 index 0000000000000..27dfb0e9df90e --- /dev/null +++ b/docs/source/dev/offline_inference/offline_index.rst @@ -0,0 +1,8 @@ +Offline Inference +================================= + +.. toctree:: + :maxdepth: 1 + + llm + llm_inputs diff --git a/docs/source/dev/sampling_params.rst b/docs/source/dev/sampling_params.rst index 844859b3ec1f0..f645941a6c022 100644 --- a/docs/source/dev/sampling_params.rst +++ b/docs/source/dev/sampling_params.rst @@ -1,4 +1,5 @@ -Sampling Params -=============== +Sampling Parameters +=================== -.. automodule:: vllm.sampling_params.SamplingParams \ No newline at end of file +.. autoclass:: vllm.SamplingParams + :members: diff --git a/docs/source/generate_examples.py b/docs/source/generate_examples.py new file mode 100644 index 0000000000000..79b49a186236a --- /dev/null +++ b/docs/source/generate_examples.py @@ -0,0 +1,61 @@ +import re +from pathlib import Path + + +def fix_case(text: str) -> str: + subs = [ + ("api", "API"), + ("llm", "LLM"), + ("vllm", "vLLM"), + ("openai", "OpenAI"), + ("multilora", "MultiLoRA"), + ] + for sub in subs: + text = re.sub(*sub, text, flags=re.IGNORECASE) + return text + + +def underline(title: str, character: str = "=") -> str: + return f"{title}\n{character * len(title)}" + + +def generate_title(filename: str) -> str: + # Turn filename into a title + title = filename.replace("_", " ").title() + # Handle acronyms and names + title = fix_case(title) + # Underline title + title = underline(title) + return title + + +def generate_examples(): + root_dir = Path(__file__).parent.parent.parent.resolve() + + # Source paths + script_dir = root_dir / "examples" + script_paths = sorted(script_dir.glob("*.py")) + + # Destination paths + doc_dir = root_dir / "docs/source/getting_started/examples" + doc_paths = [doc_dir / f"{path.stem}.rst" for path in script_paths] + + # Generate the example docs for each example script + for script_path, doc_path in zip(script_paths, doc_paths): + script_url = f"https://github.com/vllm-project/vllm/blob/main/examples/{script_path.name}" + # Make script_path relative to doc_path and call it include_path + include_path = '../../../..' / script_path.relative_to(root_dir) + content = (f"{generate_title(doc_path.stem)}\n\n" + f"Source {script_url}.\n\n" + f".. literalinclude:: {include_path}\n" + " :language: python\n" + " :linenos:\n") + with open(doc_path, "w+") as f: + f.write(content) + + # Generate the toctree for the example scripts + with open(doc_dir / "examples_index.template.rst") as f: + examples_index = f.read() + with open(doc_dir / "examples_index.rst", "w+") as f: + example_docs = "\n ".join(path.stem for path in script_paths) + f.write(examples_index.replace(r"%EXAMPLE_DOCS%", example_docs)) diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 3d736bf7120ec..61fcd45a26347 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -3,9 +3,7 @@ Installation with ROCm ====================== -vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm. -At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported. -Data types currently supported in ROCm are FP16 and BF16. +vLLM supports AMD GPUs with ROCm 5.7 and 6.0. Requirements ------------ @@ -13,114 +11,57 @@ Requirements * OS: Linux * Python: 3.8 -- 3.11 * GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100) -* Pytorch 2.0.1/2.1.1/2.2 -* ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9) +* ROCm 6.0 and ROCm 5.7 Installation options: -#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image ` -#. :ref:`Build from source ` #. :ref:`Build from source with docker ` +#. :ref:`Build from source ` -.. _quick_start_docker_rocm: - -(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image ---------------------------------------------------------------------------- - -This option is for ROCm 5.7 only: - -.. code-block:: console - - $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4 - $ docker run -it \ - --network=host \ - --group-add=video \ - --ipc=host \ - --cap-add=SYS_PTRACE \ - --security-opt seccomp=unconfined \ - --device /dev/kfd \ - --device /dev/dri \ - -v :/app/model \ - embeddedllminfo/vllm-rocm \ - bash - - -.. _build_from_source_rocm: - -Option 2: Build from source ---------------------------- - -You can build and install vLLM from source: - -Below instruction is for ROCm 5.7 only. -At the time of this documentation update, PyTorch on ROCm 6.0 wheel is not yet available on the PyTorch website. - -0. Install prerequisites (skip if you are already in an environment/docker with the following installed): - -- `ROCm `_ -- `Pytorch `_ - - .. code-block:: console - - $ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version - - -1. Install `flash attention for ROCm `_ - - Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention `_ - -.. note:: - - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly. - - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`. - - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) - -2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention +.. _build_from_source_docker_rocm: - .. code-block:: console +Option 1: Build from source with docker (recommended) +----------------------------------------------------- - $ pip install xformers==0.0.23 --no-deps - $ bash patch_xformers.rocm.sh +You can build and install vLLM from source. -3. Build vLLM. +First, build a docker image from `Dockerfile.rocm `_ and launch a docker container from the image. - .. code-block:: console +`Dockerfile.rocm `_ uses ROCm 6.0 by default, but also supports ROCm 5.7. +It provides flexibility to customize the build of docker image using the following arguments: - $ cd vllm - $ pip install -U -r requirements-rocm.txt - $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation +* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1` +* `BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For `Radeon RX 7900 series (gfx1100) `_, this should be set to 0 before flash-attention supports this target. +* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942` +* `FA_BRANCH`: specifies the branch used to build the CK flash-attention in `ROCm's flash-attention repo `_. The default is `ae7928c` +* `BUILD_TRITON`: specifies whether to build triton flash-attention. The default value is 1. +Their values can be passed in when running ``docker build`` with ``--build-arg`` options. -.. _build_from_source_docker_rocm: -Option 3: Build from source with docker ------------------------------------------------------ +To build vllm on ROCm 6.0 for MI200 and MI300 series, you can use the default: -You can build and install vLLM from source: +.. code-block:: console -Build a docker image from `Dockerfile.rocm`, and launch a docker container. + $ docker build -f Dockerfile.rocm -t vllm-rocm . -The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later versions. It provides flexibility to customize the build of docker image using the following arguments: +To build vllm on ROCm 6.0 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below: -* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1` -* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942` -* `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo `_. The default is `3d2b6f5` -* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) `_, this should be set to 0 before flash-attention supports this target. +.. code-block:: console -Their values can be passed in when running ``docker build`` with ``--build-arg`` options. + $ docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm . -For example, to build docker image for vllm on ROCm 5.7, you can run: +To build docker image for vllm on ROCm 5.7, you can specify ``BASE_IMAGE`` as below: .. code-block:: console $ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \ -f Dockerfile.rocm -t vllm-rocm . -To build vllm on ROCm 6.0, you can use the default: +To run the above docker image ``vllm-rocm``, use the below command: .. code-block:: console - $ docker build -f Dockerfile.rocm -t vllm-rocm . $ docker run -it \ --network=host \ --group-add=video \ @@ -133,7 +74,13 @@ To build vllm on ROCm 6.0, you can use the default: vllm-rocm \ bash -Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below: +Where the `` is the location where the model is stored, for example, the weights for llama2 or llama3 models. + + +.. _build_from_source_rocm: + +Option 2: Build from source +--------------------------- 0. Install prerequisites (skip if you are already in an environment/docker with the following installed): @@ -141,32 +88,50 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from - `Pytorch `_ - `hipBLAS `_ -1. Install `flash attention for ROCm `_ +For installing PyTorch, you can start from a fresh docker image, e.g, `rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`. - Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention `_ +Alternatively, you can install pytorch using pytorch wheels. You can check Pytorch installation guild in Pytorch `Getting Started `_ + +For rocm6.0: + +.. code-block:: console + + $ pip3 install torch --index-url https://download.pytorch.org/whl/rocm6.0 + + +For rocm5.7: + +.. code-block:: console + + $ pip install torch --index-url https://download.pytorch.org/whl/rocm5.7 + + +1. Install `Triton flash attention for ROCm `_ + +Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton `_ + +2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm `_ + +Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention `_ .. note:: - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly. - - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`. + - If you fail to install `ROCm/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`. - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) -2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention - - .. code-block:: console - - $ pip install xformers==0.0.23 --no-deps - $ bash patch_xformers.rocm.sh - 3. Build vLLM. - .. code-block:: console +.. code-block:: console - $ cd vllm - $ pip install -U -r requirements-rocm.txt - $ python setup.py install # This may take 5-10 minutes. + $ cd vllm + $ pip install -U -r requirements-rocm.txt + $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation -.. note:: - - You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation. +.. tip:: + - You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation. + - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. + - To use CK flash-attention, please use this flag ``export VLLM_USE_FLASH_ATTN_TRITON=0`` to turn off triton flash attention. + - The ROCm version of pytorch, ideally, should match the ROCm driver version. diff --git a/docs/source/getting_started/examples/examples_index.template.rst b/docs/source/getting_started/examples/examples_index.template.rst new file mode 100644 index 0000000000000..1b34cccbae15a --- /dev/null +++ b/docs/source/getting_started/examples/examples_index.template.rst @@ -0,0 +1,8 @@ +Examples +================================= + +.. toctree:: + :maxdepth: 1 + :caption: Scripts + + %EXAMPLE_DOCS% diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 5dfb32080f97a..ba23e7468dcc1 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -53,8 +53,13 @@ You can also build and install vLLM from source: $ git clone https://github.com/vllm-project/vllm.git $ cd vllm + $ # export VLLM_INSTALL_PUNICA_KERNELS=1 # optionally build for multi-LoRA capability $ pip install -e . # This may take 5-10 minutes. +.. tip:: + + Building from source requires quite a lot compilation. If you are building from source for multiple times, it is beneficial to cache the compilation results. For example, you can install `ccache `_ via either `conda install ccache` or `apt install ccache` . As long as `which ccache` command can find the `ccache` binary, it will be used automatically by the build system. After the first build, the subsequent builds will be much faster. + .. tip:: To avoid your system being overloaded, you can limit the number of compilation jobs to be run simultaneously, via the environment variable `MAX_JOBS`. For example: @@ -85,13 +90,3 @@ You can also build and install vLLM from source: $ nvcc --version # verify that nvcc is in your PATH $ ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME - -.. note:: - If you are developing the C++ backend of vLLM, consider building vLLM with - - .. code-block:: console - - $ python setup.py develop - - since it will give you incremental builds. The downside is that this method - is `deprecated by setuptools `_. diff --git a/docs/source/index.rst b/docs/source/index.rst index 5d5d52696ba34..5f18fe9ae0a73 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -50,6 +50,7 @@ For more information, check out the following: * `vLLM announcing blog post `_ (intro to PagedAttention) * `vLLM paper `_ (SOSP 2023) * `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency `_ by Cade Daniel et al. +* :ref:`vLLM Meetups `. @@ -65,6 +66,7 @@ Documentation getting_started/neuron-installation getting_started/cpu-installation getting_started/quickstart + getting_started/examples/examples_index .. toctree:: :maxdepth: 1 @@ -74,6 +76,7 @@ Documentation serving/deploying_with_docker serving/distributed_serving serving/metrics + serving/env_vars serving/usage_stats serving/integrations @@ -85,6 +88,7 @@ Documentation models/adding_model models/engine_args models/lora + models/performance .. toctree:: :maxdepth: 1 @@ -97,10 +101,19 @@ Documentation .. toctree:: :maxdepth: 2 :caption: Developer Documentation - + dev/sampling_params + dev/offline_inference/offline_index dev/engine/engine_index dev/kernel/paged_attention + dev/dockerfile/dockerfile + +.. toctree:: + :maxdepth: 2 + :caption: Community + + community/meetups + community/sponsors Indices and tables ================== diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index a82c2cef10e83..cbc8099e6f70f 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -95,7 +95,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a 5. Register your model ---------------------- -Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py `_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py `_. +Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/__init__.py `_. 6. Out-of-Tree Model Integration -------------------------------------------- diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index d8a7ac72e0175..bdf566d3ebbd1 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -5,132 +5,19 @@ Engine Arguments Below, you can find an explanation of every engine argument for vLLM: -.. option:: --model - - Name or path of the huggingface model to use. - -.. option:: --tokenizer - - Name or path of the huggingface tokenizer to use. - -.. option:: --revision - - The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. - -.. option:: --tokenizer-revision - - The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. - -.. option:: --tokenizer-mode {auto,slow} - - The tokenizer mode. - - * "auto" will use the fast tokenizer if available. - * "slow" will always use the slow tokenizer. - -.. option:: --trust-remote-code - - Trust remote code from huggingface. - -.. option:: --download-dir - - Directory to download and load the weights, default to the default cache dir of huggingface. - -.. option:: --load-format {auto,pt,safetensors,npcache,dummy} - - The format of the model weights to load. - - * "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available. - * "pt" will load the weights in the pytorch bin format. - * "safetensors" will load the weights in the safetensors format. - * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. - * "dummy" will initialize the weights with random values, mainly for profiling. - -.. option:: --dtype {auto,half,float16,bfloat16,float,float32} - - Data type for model weights and activations. - - * "auto" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. - * "half" for FP16. Recommended for AWQ quantization. - * "float16" is the same as "half". - * "bfloat16" for a balance between precision and range. - * "float" is shorthand for FP32 precision. - * "float32" for FP32 precision. - -.. option:: --max-model-len - - Model context length. If unspecified, will be automatically derived from the model config. - -.. option:: --worker-use-ray - - Use Ray for distributed serving, will be automatically set when using more than 1 GPU. - -.. option:: --pipeline-parallel-size (-pp) - - Number of pipeline stages. - -.. option:: --tensor-parallel-size (-tp) - - Number of tensor parallel replicas. - -.. option:: --max-parallel-loading-workers - - Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models. - -.. option:: --block-size {8,16,32} - - Token block size for contiguous chunks of tokens. - -.. option:: --enable-prefix-caching - - Enables automatic prefix caching - -.. option:: --seed - - Random seed for operations. - -.. option:: --swap-space - - CPU swap space size (GiB) per GPU. - -.. option:: --gpu-memory-utilization - - The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. - For example, a value of 0.5 would imply 50% GPU memory utilization. - If unspecified, will use the default value of 0.9. - -.. option:: --max-num-batched-tokens - - Maximum number of batched tokens per iteration. - -.. option:: --max-num-seqs - - Maximum number of sequences per iteration. - -.. option:: --max-paddings - - Maximum number of paddings in a batch. - -.. option:: --disable-log-stats - - Disable logging statistics. - -.. option:: --quantization (-q) {awq,squeezellm,None} - - Method used to quantize the weights. +.. argparse:: + :module: vllm.engine.arg_utils + :func: _engine_args_parser + :prog: -m vllm.entrypoints.openai.api_server + :nodefaultconst: Async Engine Arguments ---------------------- -Below are the additional arguments related to the asynchronous engine: - -.. option:: --engine-use-ray - Use Ray to start the LLM engine in a separate process as the server process. - -.. option:: --disable-log-requests - - Disable logging requests. - -.. option:: --max-log-len +Below are the additional arguments related to the asynchronous engine: - Max number of prompt characters or prompt ID numbers being printed in log. Defaults to unlimited. \ No newline at end of file +.. argparse:: + :module: vllm.engine.arg_utils + :func: _async_engine_args_parser + :prog: -m vllm.entrypoints.openai.api_server + :nodefaultconst: \ No newline at end of file diff --git a/docs/source/models/performance.rst b/docs/source/models/performance.rst new file mode 100644 index 0000000000000..d8750ddc34e8e --- /dev/null +++ b/docs/source/models/performance.rst @@ -0,0 +1,63 @@ +.. _performance: + +Performance and Tuning +====================== + +Preemption +---------- +Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. +The vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes +available again. When this occurs, the following warning is printed: + +``` +WARNING 05-09 00:49:33 scheduler.py:1057] Sequence group 0 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_cumulative_preemption_cnt=1 +``` + +While this mechanism ensures system robustness, preemption and recomputation can adversely affect end-to-end latency. +If you frequently encounter preemptions from the vLLM engine, consider the following actions: + +- Increase `gpu_memory_utilization`. The vLLM pre-allocates GPU cache by using gpu_memory_utilization% of memory. By increasing this utilization, you can provide more KV cache space. +- Decrease `max_num_seqs` or `max_num_batched_tokens`. This can reduce the number of concurrent requests in a batch, thereby requiring less KV cache space. +- Increase `tensor_parallel_size`. This approach shards model weights, so each GPU has more memory available for KV cache. + +You can also monitor the number of preemption requests through Prometheus metrics exposed by the vLLM. Additionally, you can log the cumulative number of preemption requests by setting disable_log_stats=False. + +Chunked Prefill +--------------- +vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. + +You can enable the feature by specifying ``--enable-chunked-prefill`` in the command line or setting ``enable_chunked_prefill=True`` in the LLM constructor. + +.. code-block:: python + + llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True) + # Set max_num_batched_tokens to tune performance. + # NOTE: 512 is the default max_num_batched_tokens for chunked prefill. + # llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=512) + +By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. +This policy optimizes the TTFT (time to the first token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. + +Once chunked prefill is enabled, the policy is changed to prioritize decode requests. +It batches all pending decode requests to the batch before scheduling any prefill. +When there are available token_budget (``max_num_batched_tokens``), it schedules pending prefills. +If a last pending prefill request cannot fit into ``max_num_batched_tokens``, it chunks it. + +This policy has two benefits: + +- It improves ITL and generation decode because decode requests are prioritized. +- It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch. + +You can tune the performance by changing ``max_num_batched_tokens``. +By default, it is set to 512, which has the best ITL on A100 in the initial benchmark (llama 70B and mixtral 8x22B). +Smaller ``max_num_batched_tokens`` achieves better ITL because there are fewer prefills interrupting decodes. +Higher ``max_num_batched_tokens`` achieves better TTFT as you can put more prefill to the batch. + +- If ``max_num_batched_tokens`` is the same as ``max_model_len``, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). +- Note that the default value (512) of ``max_num_batched_tokens`` is optimized for ITL, and it may have lower throughput than the default scheduler. + +We recommend you set ``max_num_batched_tokens > 2048`` for throughput. + +See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). + +Please try out this feature and let us know your feedback via GitHub issues! \ No newline at end of file diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index e7bfdcb65316e..82e71e61975c8 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -16,13 +16,21 @@ Alongside each architecture, we include some popular models that use it. - Example HuggingFace Models - :ref:`LoRA ` * - :code:`AquilaForCausalLM` - - Aquila + - Aquila & Aquila2 - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc. - ✅︎ + * - :code:`ArcticForCausalLM` + - Arctic + - :code:`Snowflake/snowflake-arctic-base`, :code:`Snowflake/snowflake-arctic-instruct`, etc. + - * - :code:`BaiChuanForCausalLM` - - Baichuan + - Baichuan & Baichuan2 - :code:`baichuan-inc/Baichuan2-13B-Chat`, :code:`baichuan-inc/Baichuan-7B`, etc. - ✅︎ + * - :code:`BloomForCausalLM` + - BLOOM, BLOOMZ, BLOOMChat + - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. + - * - :code:`ChatGLMModel` - ChatGLM - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. @@ -30,23 +38,19 @@ Alongside each architecture, we include some popular models that use it. * - :code:`CohereForCausalLM` - Command-R - :code:`CohereForAI/c4ai-command-r-v01`, etc. - - + - * - :code:`DbrxForCausalLM` - DBRX - :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc. - - + - * - :code:`DeciLMForCausalLM` - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - - - * - :code:`BloomForCausalLM` - - BLOOM, BLOOMZ, BLOOMChat - - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. - - + - * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. - - + - * - :code:`GemmaForCausalLM` - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. @@ -54,19 +58,19 @@ Alongside each architecture, we include some popular models that use it. * - :code:`GPT2LMHeadModel` - GPT-2 - :code:`gpt2`, :code:`gpt2-xl`, etc. - - + - * - :code:`GPTBigCodeForCausalLM` - StarCoder, SantaCoder, WizardCoder - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc. - - + - ✅︎ * - :code:`GPTJForCausalLM` - GPT-J - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc. - - + - * - :code:`GPTNeoXForCausalLM` - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc. - - + - * - :code:`InternLMForCausalLM` - InternLM - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc. @@ -80,8 +84,8 @@ Alongside each architecture, we include some popular models that use it. - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. - * - :code:`LlamaForCausalLM` - - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi - - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. + - LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi + - :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. - ✅︎ * - :code:`MiniCPMForCausalLM` - MiniCPM @@ -93,32 +97,40 @@ Alongside each architecture, we include some popular models that use it. - ✅︎ * - :code:`MixtralForCausalLM` - Mixtral-8x7B, Mixtral-8x7B-Instruct - - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc. + - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc. - ✅︎ * - :code:`MPTForCausalLM` - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. - - + - * - :code:`OLMoForCausalLM` - OLMo - - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc. - - + - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc. + - * - :code:`OPTForCausalLM` - OPT, OPT-IML - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc. - - + - * - :code:`OrionForCausalLM` - Orion - :code:`OrionStarAI/Orion-14B-Base`, :code:`OrionStarAI/Orion-14B-Chat`, etc. - - + - * - :code:`PhiForCausalLM` - Phi - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. - - + - ✅︎ + * - :code:`Phi3ForCausalLM` + - Phi-3 + - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc. + - + * - :code:`Phi3SmallForCausalLM` + - Phi-3-Small + - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. + - * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. - - + - * - :code:`Qwen2ForCausalLM` - Qwen2 - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc. @@ -126,11 +138,20 @@ Alongside each architecture, we include some popular models that use it. * - :code:`Qwen2MoeForCausalLM` - Qwen2MoE - :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. - - + - * - :code:`StableLmForCausalLM` - StableLM - :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. - - + - + * - :code:`Starcoder2ForCausalLM` + - Starcoder2 + - :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc. + - + * - :code:`XverseForCausalLM` + - Xverse + - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc. + - + If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. Otherwise, please refer to :ref:`Adding a New Model ` for instructions on how to implement support for your model. @@ -168,3 +189,29 @@ Alternatively, you can raise an issue on our `GitHub `_ and `test_big_models.py `_ for the models that have passed this test. +2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. +3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests `_ and `examples `_ for the models that have passed this test. +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. diff --git a/docs/source/serving/deploying_with_docker.rst b/docs/source/serving/deploying_with_docker.rst index 7ec769630300d..fa82bc8e3bd33 100644 --- a/docs/source/serving/deploying_with_docker.rst +++ b/docs/source/serving/deploying_with_docker.rst @@ -49,3 +49,6 @@ To run vLLM: --env "HUGGING_FACE_HUB_TOKEN=" \ vllm/vllm-openai +.. note:: + + **For `v0.4.1` and `v0.4.2` only** - the vLLM docker images under these versions are supposed to be run under the root user since a library under the root user's home directory, i.e. ``/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1`` is required to be loaded during runtime. If you are running the container under a different user, you may need to first change the permissions of the library (and all the parent directories) to allow the user to access it, then run vLLM with environment variable ``VLLM_NCCL_SO_PATH=/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1`` . diff --git a/docs/source/serving/deploying_with_dstack.rst b/docs/source/serving/deploying_with_dstack.rst new file mode 100644 index 0000000000000..baf87314ca8e4 --- /dev/null +++ b/docs/source/serving/deploying_with_dstack.rst @@ -0,0 +1,103 @@ +.. _deploying_with_dstack: + +Deploying with dstack +============================ + +.. raw:: html + +

+ vLLM_plus_dstack +

+ +vLLM can be run on a cloud based GPU machine with `dstack `__, an open-source framework for running LLMs on any cloud. This tutorial assumes that you have already configured credentials, gateway, and GPU quotas on your cloud environment. + +To install dstack client, run: + +.. code-block:: console + + $ pip install "dstack[all] + $ dstack server + +Next, to configure your dstack project, run: + +.. code-block:: console + + $ mkdir -p vllm-dstack + $ cd vllm-dstack + $ dstack init + +Next, to provision a VM instance with LLM of your choice(`NousResearch/Llama-2-7b-chat-hf` for this example), create the following `serve.dstack.yml` file for the dstack `Service`: + +.. code-block:: yaml + + type: service + + python: "3.11" + env: + - MODEL=NousResearch/Llama-2-7b-chat-hf + port: 8000 + resources: + gpu: 24GB + commands: + - pip install vllm + - python -m vllm.entrypoints.openai.api_server --model $MODEL --port 8000 + model: + format: openai + type: chat + name: NousResearch/Llama-2-7b-chat-hf + +Then, run the following CLI for provisioning: + +.. code-block:: console + + $ dstack run . -f serve.dstack.yml + + ⠸ Getting run plan... + Configuration serve.dstack.yml + Project deep-diver-main + User deep-diver + Min resources 2..xCPU, 8GB.., 1xGPU (24GB) + Max price - + Max duration - + Spot policy auto + Retry policy no + + # BACKEND REGION INSTANCE RESOURCES SPOT PRICE + 1 gcp us-central1 g2-standard-4 4xCPU, 16GB, 1xL4 (24GB), 100GB (disk) yes $0.223804 + 2 gcp us-east1 g2-standard-4 4xCPU, 16GB, 1xL4 (24GB), 100GB (disk) yes $0.223804 + 3 gcp us-west1 g2-standard-4 4xCPU, 16GB, 1xL4 (24GB), 100GB (disk) yes $0.223804 + ... + Shown 3 of 193 offers, $5.876 max + + Continue? [y/n]: y + ⠙ Submitting run... + ⠏ Launching spicy-treefrog-1 (pulling) + spicy-treefrog-1 provisioning completed (running) + Service is published at ... + +After the provisioning, you can interact with the model by using the OpenAI SDK: + +.. code-block:: python + + from openai import OpenAI + + client = OpenAI( + base_url="https://gateway.", + api_key="" + ) + + completion = client.chat.completions.create( + model="NousResearch/Llama-2-7b-chat-hf", + messages=[ + { + "role": "user", + "content": "Compose a poem that explains the concept of recursion in programming.", + } + ] + ) + + print(completion.choices[0].message.content) + +.. note:: + + dstack automatically handles authentication on the gateway using dstack's tokens. Meanwhile, if you don't want to configure a gateway, you can provision dstack `Task` instead of `Service`. The `Task` is for development purpose only. If you want to know more about hands-on materials how to serve vLLM using dstack, check out `this repository `__ diff --git a/docs/source/serving/deploying_with_lws.rst b/docs/source/serving/deploying_with_lws.rst new file mode 100644 index 0000000000000..b63a432dde0d5 --- /dev/null +++ b/docs/source/serving/deploying_with_lws.rst @@ -0,0 +1,12 @@ +.. _deploying_with_lws: + +Deploying with LWS +============================ + +LeaderWorkerSet (LWS) is a Kubernetes API that aims to address common deployment patterns of AI/ML inference workloads. +A major use case is for multi-host/multi-node distributed inference. + +vLLM can be deployed with `LWS `_ on Kubernetes for distributed model serving. + +Please see `this guide `_ for more details on +deploying vLLM on Kubernetes using LWS. diff --git a/docs/source/serving/env_vars.rst b/docs/source/serving/env_vars.rst new file mode 100644 index 0000000000000..0ce1374a3967b --- /dev/null +++ b/docs/source/serving/env_vars.rst @@ -0,0 +1,9 @@ +Environment Variables +======================== + +vLLM uses the following environment variables to configure the system: + +.. literalinclude:: ../../../vllm/envs.py + :language: python + :start-after: begin-env-vars-definition + :end-before: end-env-vars-definition diff --git a/docs/source/serving/integrations.rst b/docs/source/serving/integrations.rst index 93872397913e3..83a8b5a88bd38 100644 --- a/docs/source/serving/integrations.rst +++ b/docs/source/serving/integrations.rst @@ -8,4 +8,6 @@ Integrations deploying_with_kserve deploying_with_triton deploying_with_bentoml + deploying_with_lws + deploying_with_dstack serving_with_langchain diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 032fe5d03bd52..15a8761eb5738 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat You can start the server using Python, or using [Docker](deploying_with_docker.rst): ```bash -python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123 +python -m vllm.entrypoints.openai.api_server --model NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 ``` To call the server, you can use the official OpenAI Python client library, or any other HTTP client. @@ -16,9 +16,8 @@ client = OpenAI( ) completion = client.chat.completions.create( - model="meta-llama/Llama-2-7b-hf", + model="NousResearch/Meta-Llama-3-8B-Instruct", messages=[ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"} ] ) @@ -38,9 +37,8 @@ Or directly merge them into the JSON payload if you are using HTTP call directly ```python completion = client.chat.completions.create( - model="meta-llama/Llama-2-7b-hf", + model="NousResearch/Meta-Llama-3-8B-Instruct", messages=[ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} ], extra_body={ @@ -89,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode a chat template in its tokenizer configuration. The chat template is a Jinja2 template that specifies how are roles, messages, and other chat-specific tokens are encoded in the input. -An example chat template for `meta-llama/Llama-2-7b-chat-hf` can be found [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/09bd0f49e16738cdfaa6e615203e126038736eb0/tokenizer_config.json#L12) +An example chat template for `NousResearch/Meta-Llama-3-8B-Instruct` can be found [here](https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models) Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model, you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat @@ -110,5 +108,5 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) ```{argparse} :module: vllm.entrypoints.openai.cli_args :func: make_arg_parser -:prog: vllm-openai-server +:prog: -m vllm.entrypoints.openai.api_server ``` \ No newline at end of file diff --git a/docs/source/serving/run_on_sky.rst b/docs/source/serving/run_on_sky.rst index 2c88d24dc5d0b..bd33c76cec3de 100644 --- a/docs/source/serving/run_on_sky.rst +++ b/docs/source/serving/run_on_sky.rst @@ -1,7 +1,7 @@ .. _on_cloud: -Running on clouds with SkyPilot -=============================== +Deploying and scaling up with SkyPilot +================================================ .. raw:: html @@ -9,51 +9,75 @@ Running on clouds with SkyPilot vLLM

-vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot `__, an open-source framework for running LLMs on any cloud. +vLLM can be **run and scaled to multiple service replicas on clouds and Kubernetes** with `SkyPilot `__, an open-source framework for running LLMs on any cloud. More examples for various open models, such as Llama-3, Mixtral, etc, can be found in `SkyPilot AI gallery `__. -To install SkyPilot and setup your cloud credentials, run: + +Prerequisites +------------- + +- Go to the `HuggingFace model page `__ and request access to the model :code:`meta-llama/Meta-Llama-3-8B-Instruct`. +- Check that you have installed SkyPilot (`docs `__). +- Check that :code:`sky check` shows clouds or Kubernetes are enabled. .. code-block:: console - $ pip install skypilot - $ sky check + pip install skypilot-nightly + sky check + + +Run on a single instance +------------------------ See the vLLM SkyPilot YAML for serving, `serving.yaml `__. .. code-block:: yaml resources: - accelerators: A100 + accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB} # We can use cheaper accelerators for 8B model. + use_spot: True + disk_size: 512 # Ensure model checkpoints can fit. + disk_tier: best + ports: 8081 # Expose to internet traffic. envs: - MODEL_NAME: decapoda-research/llama-13b-hf - TOKENIZER: hf-internal-testing/llama-tokenizer + MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct + HF_TOKEN: # Change to your own huggingface token, or use --env to pass. setup: | - conda create -n vllm python=3.9 -y + conda create -n vllm python=3.10 -y conda activate vllm - git clone https://github.com/vllm-project/vllm.git - cd vllm - pip install . - pip install gradio + + pip install vllm==0.4.0.post1 + # Install Gradio for web UI. + pip install gradio openai + pip install flash-attn==2.5.7 run: | conda activate vllm echo 'Starting vllm api server...' - python -u -m vllm.entrypoints.api_server \ - --model $MODEL_NAME \ - --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \ - --tokenizer $TOKENIZER 2>&1 | tee api_server.log & + python -u -m vllm.entrypoints.openai.api_server \ + --port 8081 \ + --model $MODEL_NAME \ + --trust-remote-code \ + --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \ + 2>&1 | tee api_server.log & + echo 'Waiting for vllm api server to start...' while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done + echo 'Starting gradio server...' - python vllm/examples/gradio_webserver.py + git clone https://github.com/vllm-project/vllm.git || true + python vllm/examples/gradio_openai_chatbot_webserver.py \ + -m $MODEL_NAME \ + --port 8811 \ + --model-url http://localhost:8081/v1 \ + --stop-token-ids 128009,128001 -Start the serving the LLaMA-13B model on an A100 GPU: +Start the serving the Llama-3 8B model on any of the candidate GPUs listed (L4, A10g, ...): .. code-block:: console - $ sky launch serving.yaml + HF_TOKEN="your-huggingface-token" sky launch serving.yaml --env HF_TOKEN Check the output of the command. There will be a shareable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion. @@ -61,9 +85,226 @@ Check the output of the command. There will be a shareable gradio link (like the (task, pid=7431) Running on public URL: https://.gradio.live -**Optional**: Serve the 65B model instead of the default 13B and use more GPU: +**Optional**: Serve the 70B model instead of the default 8B and use more GPU: + +.. code-block:: console + + HF_TOKEN="your-huggingface-token" sky launch serving.yaml --gpus A100:8 --env HF_TOKEN --env MODEL_NAME=meta-llama/Meta-Llama-3-70B-Instruct + + +Scale up to multiple replicas +----------------------------- + +SkyPilot can scale up the service to multiple service replicas with built-in autoscaling, load-balancing and fault-tolerance. You can do it by adding a services section to the YAML file. + +.. code-block:: yaml + + service: + replicas: 2 + # An actual request for readiness probe. + readiness_probe: + path: /v1/chat/completions + post_data: + model: $MODEL_NAME + messages: + - role: user + content: Hello! What is your name? + max_tokens: 1 + +.. raw:: html + +
+ Click to see the full recipe YAML + + +.. code-block:: yaml + + service: + replicas: 2 + # An actual request for readiness probe. + readiness_probe: + path: /v1/chat/completions + post_data: + model: $MODEL_NAME + messages: + - role: user + content: Hello! What is your name? + max_tokens: 1 + + resources: + accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB} # We can use cheaper accelerators for 8B model. + use_spot: True + disk_size: 512 # Ensure model checkpoints can fit. + disk_tier: best + ports: 8081 # Expose to internet traffic. + + envs: + MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct + HF_TOKEN: # Change to your own huggingface token, or use --env to pass. + + setup: | + conda create -n vllm python=3.10 -y + conda activate vllm + + pip install vllm==0.4.0.post1 + # Install Gradio for web UI. + pip install gradio openai + pip install flash-attn==2.5.7 + + run: | + conda activate vllm + echo 'Starting vllm api server...' + python -u -m vllm.entrypoints.openai.api_server \ + --port 8081 \ + --model $MODEL_NAME \ + --trust-remote-code \ + --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \ + 2>&1 | tee api_server.log & + + echo 'Waiting for vllm api server to start...' + while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done + + echo 'Starting gradio server...' + git clone https://github.com/vllm-project/vllm.git || true + python vllm/examples/gradio_openai_chatbot_webserver.py \ + -m $MODEL_NAME \ + --port 8811 \ + --model-url http://localhost:8081/v1 \ + --stop-token-ids 128009,128001 + +.. raw:: html + +
+ +Start the serving the Llama-3 8B model on multiple replicas: + +.. code-block:: console + + HF_TOKEN="your-huggingface-token" sky serve up -n vllm serving.yaml --env HF_TOKEN + + +Wait until the service is ready: .. code-block:: console - sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf + watch -n10 sky serve status vllm + + +.. raw:: html + +
+ Example outputs: + +.. code-block:: console + + Services + NAME VERSION UPTIME STATUS REPLICAS ENDPOINT + vllm 1 35s READY 2/2 xx.yy.zz.100:30001 + + Service Replicas + SERVICE_NAME ID VERSION IP LAUNCHED RESOURCES STATUS REGION + vllm 1 1 xx.yy.zz.121 18 mins ago 1x GCP({'L4': 1}) READY us-east4 + vllm 2 1 xx.yy.zz.245 18 mins ago 1x GCP({'L4': 1}) READY us-east4 + +.. raw:: html + +
+ +After the service is READY, you can find a single endpoint for the service and access the service with the endpoint: + +.. code-block:: console + + ENDPOINT=$(sky serve status --endpoint 8081 vllm) + curl -L http://$ENDPOINT/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Who are you?" + } + ], + "stop_token_ids": [128009, 128001] + }' + +To enable autoscaling, you could specify additional configs in `services`: + +.. code-block:: yaml + + services: + replica_policy: + min_replicas: 0 + max_replicas: 3 + target_qps_per_replica: 2 + +This will scale the service up to when the QPS exceeds 2 for each replica. + + +**Optional**: Connect a GUI to the endpoint +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +It is also possible to access the Llama-3 service with a separate GUI frontend, so the user requests send to the GUI will be load-balanced across replicas. + +.. raw:: html + +
+ Click to see the full GUI YAML + +.. code-block:: yaml + + envs: + MODEL_NAME: meta-llama/Meta-Llama-3-70B-Instruct + ENDPOINT: x.x.x.x:3031 # Address of the API server running vllm. + + resources: + cpus: 2 + + setup: | + conda activate vllm + if [ $? -ne 0 ]; then + conda create -n vllm python=3.10 -y + conda activate vllm + fi + + # Install Gradio for web UI. + pip install gradio openai + + run: | + conda activate vllm + export PATH=$PATH:/sbin + WORKER_IP=$(hostname -I | cut -d' ' -f1) + CONTROLLER_PORT=21001 + WORKER_PORT=21002 + + echo 'Starting gradio server...' + git clone https://github.com/vllm-project/vllm.git || true + python vllm/examples/gradio_openai_chatbot_webserver.py \ + -m $MODEL_NAME \ + --port 8811 \ + --model-url http://$ENDPOINT/v1 \ + --stop-token-ids 128009,128001 | tee ~/gradio.log + +.. raw:: html + +
+ +1. Start the chat web UI: + +.. code-block:: console + + sky launch -c gui ./gui.yaml --env ENDPOINT=$(sky serve status --endpoint vllm) + + +2. Then, we can access the GUI at the returned gradio link: + +.. code-block:: console + + | INFO | stdout | Running on public URL: https://6141e84201ce0bb4ed.gradio.live + diff --git a/examples/aqlm_example.py b/examples/aqlm_example.py new file mode 100644 index 0000000000000..e7c17fa0362ae --- /dev/null +++ b/examples/aqlm_example.py @@ -0,0 +1,46 @@ +import argparse + +from vllm import LLM, SamplingParams + + +def main(): + + parser = argparse.ArgumentParser(description='AQLM examples') + + parser.add_argument('--model', + '-m', + type=str, + default=None, + help='model path, as for HF') + parser.add_argument('--choice', + '-c', + type=int, + default=0, + help='known good models by index, [0-4]') + parser.add_argument('--tensor_parallel_size', + '-t', + type=int, + default=1, + help='tensor parallel size') + + args = parser.parse_args() + + models = [ + "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", + "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf", + "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf", + "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf", + "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf", + ] + + model = LLM(args.model if args.model is not None else models[args.choice], + tensor_parallel_size=args.tensor_parallel_size) + + sampling_params = SamplingParams(max_tokens=100, temperature=0) + outputs = model.generate("Hello my name is", + sampling_params=sampling_params) + print(outputs[0].outputs[0].text) + + +if __name__ == '__main__': + main() diff --git a/examples/fp8/extract_scales.py b/examples/fp8/extract_scales.py index 5e5b31265e3af..1eb961a5a76e3 100644 --- a/examples/fp8/extract_scales.py +++ b/examples/fp8/extract_scales.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.schema import QuantParamSchema -# Adapted from vllm/model_executor/weight_utils.py +# Adapted from vllm/model_executor/model_loader/weight_utils.py # The main differences are that we add the NPZ format and simplify # its functionality drastically for our purposes (e.g. we assume that # the quantized model exists locally and there is no need to download it) @@ -71,7 +71,7 @@ def _prepare_hf_weights( return hf_weights_files, use_safetensors -# Adapted from vllm/model_executor/weight_utils.py +# Adapted from vllm/model_executor/model_loader/weight_utils.py def _hf_tensorfile_iterator(filename: str, load_format: str, use_safetensors: bool): if load_format == "npz": diff --git a/examples/llava_example.py b/examples/llava_example.py index 3d22b492654bf..60250c4303fbf 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -23,11 +23,15 @@ def run_llava_pixel_values(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. - images = torch.load("images/stop_sign_pixel_values.pt") + image = torch.load("images/stop_sign_pixel_values.pt") + + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + }) - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) for o in outputs: generated_text = o.outputs[0].text print(generated_text) @@ -46,11 +50,14 @@ def run_llava_image_features(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. - images = torch.load("images/stop_sign_image_features.pt") - - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) + image = torch.load("images/stop_sign_image_features.pt") + + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + }) for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/examples/logging_configuration.md b/examples/logging_configuration.md new file mode 100644 index 0000000000000..75b4b31a80462 --- /dev/null +++ b/examples/logging_configuration.md @@ -0,0 +1,178 @@ +# Logging Configuration + +vLLM leverages Python's `logging.config.dictConfig` functionality to enable +robust and flexible configuration of the various loggers used by vLLM. + +vLLM offers two environment variables that can be used to accommodate a range +of logging configurations that range from simple-and-inflexible to +more-complex-and-more-flexible. + +- No vLLM logging (simple and inflexible) + - Set `VLLM_CONFIGURE_LOGGING=0` (leaving `VLLM_LOGGING_CONFIG_PATH` unset) +- vLLM's default logging configuration (simple and inflexible) + - Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1` +- Fine-grained custom logging configuration (more complex, more flexible) + - Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1` and + set `VLLM_LOGGING_CONFIG_PATH=` + + +## Logging Configuration Environment Variables + +### `VLLM_CONFIGURE_LOGGING` + +`VLLM_CONFIGURE_LOGGING` controls whether or not vLLM takes any action to +configure the loggers used by vLLM. This functionality is enabled by default, +but can be disabled by setting `VLLM_CONFIGURE_LOGGING=0` when running vLLM. + +If `VLLM_CONFIGURE_LOGGING` is enabled and no value is given for +`VLLM_LOGGING_CONFIG_PATH`, vLLM will use built-in default configuration to +configure the root vLLM logger. By default, no other vLLM loggers are +configured and, as such, all vLLM loggers defer to the root vLLM logger to make +all logging decisions. + +If `VLLM_CONFIGURE_LOGGING` is disabled and a value is given for +`VLLM_LOGGING_CONFIG_PATH`, an error will occur while starting vLLM. + +### `VLLM_LOGGING_CONFIG_PATH` + +`VLLM_LOGGING_CONFIG_PATH` allows users to specify a path to a JSON file of +alternative, custom logging configuration that will be used instead of vLLM's +built-in default logging configuration. The logging configuration should be +provided in JSON format following the schema specified by Python's [logging +configuration dictionary +schema](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details). + +If `VLLM_LOGGING_CONFIG_PATH` is specified, but `VLLM_CONFIGURE_LOGGING` is +disabled, an error will occur while starting vLLM. + + +## Examples + +### Example 1: Customize vLLM root logger + +For this example, we will customize the vLLM root logger to use +[`python-json-logger`](https://github.com/madzak/python-json-logger) to log to +STDOUT of the console in JSON format with a log level of `INFO`. + +To begin, first, create an appropriate JSON logging configuration file: + +**/path/to/logging_config.json:** + +```json +{ + "formatters": { + "json": { + "class": "pythonjsonlogger.jsonlogger.JsonFormatter" + } + }, + "handlers": { + "console": { + "class" : "logging.StreamHandler", + "formatter": "json", + "level": "INFO", + "stream": "ext://sys.stdout" + } + }, + "loggers": { + "vllm": { + "handlers": ["console"], + "level": "INFO", + "propagate": false + } + }, + "version": 1 +} +``` + +Next, install the `python-json-logger` package if it's not already installed: + +```bash +pip install python-json-logger +``` + +Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set +to the path of the custom logging configuration JSON file: + +```bash +VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ + python3 -m vllm.entrypoints.openai.api_server \ + --max-model-len 2048 \ + --model mistralai/Mistral-7B-v0.1 +``` + + +### Example 2: Silence a particular vLLM logger + +To silence a particular vLLM logger, it is necessary to provide custom logging +configuration for the target logger that configures the logger so that it won't +propagate its log messages to the root vLLM logger. + +When custom configuration is provided for any logger, it is also necessary to +provide configuration for the root vLLM logger since any custom logger +configuration overrides the built-in default logging configuration used by vLLM. + +First, create an appropriate JSON logging configuration file that includes +configuration for the root vLLM logger and for the logger you wish to silence: + +**/path/to/logging_config.json:** + +```json +{ + "formatters": { + "vllm": { + "class": "vllm.logging.NewLineFormatter", + "datefmt": "%m-%d %H:%M:%S", + "format": "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" + } + }, + "handlers": { + "vllm": { + "class" : "logging.StreamHandler", + "formatter": "vllm", + "level": "INFO", + "stream": "ext://sys.stdout" + } + }, + "loggers": { + "vllm": { + "handlers": ["vllm"], + "level": "DEBUG", + "propagage": false + }, + "vllm.example_noisy_logger": { + "propagate": false + } + }, + "version": 1 +} +``` + +Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set +to the path of the custom logging configuration JSON file: + +```bash +VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ + python3 -m vllm.entrypoints.openai.api_server \ + --max-model-len 2048 \ + --model mistralai/Mistral-7B-v0.1 +``` + + +### Example 3: Disable vLLM default logging configuration + +To disable vLLM's default logging configuration and silence all vLLM loggers, +simple set `VLLM_CONFIGURE_LOGGING=0` when running vLLM. This will prevent vLLM +for configuring the root vLLM logger, which in turn, silences all other vLLM +loggers. + +```bash +VLLM_CONFIGURE_LOGGING=0 \ + python3 -m vllm.entrypoints.openai.api_server \ + --max-model-len 2048 \ + --model mistralai/Mistral-7B-v0.1 +``` + + +## Additional resources + +- [`logging.config` Dictionary Schema Details](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details) diff --git a/examples/offline_inference_arctic.py b/examples/offline_inference_arctic.py new file mode 100644 index 0000000000000..1fec3c99eb47c --- /dev/null +++ b/examples/offline_inference_arctic.py @@ -0,0 +1,26 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="snowflake/snowflake-arctic-instruct", + quantization="deepspeedfp", + tensor_parallel_size=8, + trust_remote_code=True) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. + +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/offline_inference_distributed.py b/examples/offline_inference_distributed.py index e4f085fa6665a..1e59e89509724 100644 --- a/examples/offline_inference_distributed.py +++ b/examples/offline_inference_distributed.py @@ -9,19 +9,31 @@ import numpy as np import ray +from packaging.version import Version +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm import LLM, SamplingParams +assert Version(ray.__version__) >= Version( + "2.22.0"), "Ray version must be at least 2.22.0" + # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +# Set tensor parallelism per instance. +tensor_parallel_size = 1 + +# Set number of instances. Each instance will use tensor_parallel_size GPUs. +num_instances = 1 + # Create a class to do batch inference. class LLMPredictor: def __init__(self): # Create an LLM. - self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf") + self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", + tensor_parallel_size=tensor_parallel_size) def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]: # Generate texts from the prompts. @@ -43,17 +55,41 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]: # from cloud storage (such as JSONL, Parquet, CSV, binary format). ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt") + +# For tensor_parallel_size > 1, we need to create placement groups for vLLM +# to use. Every actor has to have its own placement group. +def scheduling_strategy_fn(): + # One bundle per tensor parallel worker + pg = ray.util.placement_group( + [{ + "GPU": 1, + "CPU": 1 + }] * tensor_parallel_size, + strategy="STRICT_PACK", + ) + return dict(scheduling_strategy=PlacementGroupSchedulingStrategy( + pg, placement_group_capture_child_tasks=True)) + + +resources_kwarg = {} +if tensor_parallel_size == 1: + # For tensor_parallel_size == 1, we simply set num_gpus=1. + resources_kwarg["num_gpus"] = 1 +else: + # Otherwise, we have to set num_gpus=0 and provide + # a function that will create a placement group for + # each instance. + resources_kwarg["num_gpus"] = 0 + resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn + # Apply batch inference for all input data. ds = ds.map_batches( LLMPredictor, # Set the concurrency to the number of LLM instances. - concurrency=10, - # Specify the number of GPUs required per LLM instance. - # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism - # (i.e., `tensor_parallel_size`). - num_gpus=1, + concurrency=num_instances, # Specify the batch size for inference. batch_size=32, + **resources_kwarg, ) # Peek first 10 results. diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py new file mode 100644 index 0000000000000..7d5ef128bc8e0 --- /dev/null +++ b/examples/offline_inference_embedding.py @@ -0,0 +1,17 @@ +from vllm import LLM + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create an LLM. +model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) +# Generate embedding. The output is a list of EmbeddingRequestOutputs. +outputs = model.encode(prompts) +# Print the outputs. +for output in outputs: + print(output.outputs.embedding) # list of 4096 floats diff --git a/examples/offline_inference_openai.md b/examples/offline_inference_openai.md new file mode 100644 index 0000000000000..40462ce1eb78c --- /dev/null +++ b/examples/offline_inference_openai.md @@ -0,0 +1,172 @@ +# Offline Inference with the OpenAI Batch file format + + **NOTE:** This is a guide to performing batch inference using the OpenAI batch file format, **NOT** the complete Batch (REST) API. + + ## File Format + + The OpenAI batch file format consists of a series of json objects on new lines. + + [See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/openai_example_batch.jsonl) + + Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details. + + **NOTE:** We currently only support to `/v1/chat/completions` endpoint (embeddings and completions coming soon). + + ## Pre-requisites + +* Ensure you are using `vllm >= 0.4.3`. You can check by running `python -c "import vllm; print(vllm.__version__)"`. +* The examples in this document use `meta-llama/Meta-Llama-3-8B-Instruct`. + - Create a [user access token](https://huggingface.co/docs/hub/en/security-tokens) + - Install the token on your machine (Run `huggingface-cli login`). + - Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions. + + + ## Example: Running with a local file + + ### Step 1: Create your batch file + + To follow along with this example, you can download the example batch, or create your own batch file in your working directory. + + ``` + wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl + ``` + + Once you've created your batch file it should look like this + + ``` + $ cat openai_example_batch.jsonl +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} + ``` + + ### Step 2: Run the batch + +The batch running tool is designed to be used from the command line. + +You can run the batch with the following command, which will write its results to a file called `results.jsonl` + +``` +python -m vllm.entrypoints.openai.run_batch -i openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +### Step 3: Check your results + +You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl` + +``` +$ cat ../results.jsonl +{"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null} +{"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null} +``` + +## Example 2: Using remote files + +The batch runner supports remote input and output urls that are accessible via http/https. + +For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl`, you can run + +``` +python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +## Example 3: Integrating with AWS S3 + +To integrate with cloud blob storage, we recommend using presigned urls. + +[Learn more about S3 presigned urls here] + +### Additional prerequisites + +* [Create an S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html). +* The `awscli` package (Run `pip install awscli`) to configure your credentials and interactively use s3. + - [Configure your credentials](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-quickstart.html). +* The `boto3` python package (Run `pip install boto3`) to generate presigned urls. + +### Step 1: Upload your input script + +To follow along with this example, you can download the example batch, or create your own batch file in your working directory. + + ``` + wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl + ``` + + Once you've created your batch file it should look like this + + ``` + $ cat openai_example_batch.jsonl +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} + ``` + +Now upload your batch file to your S3 bucket. + +``` +aws s3 cp openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl +``` + + +### Step 2: Generate your presigned urls + +Presigned put urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names. + +(The script is adapted from https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py) + +``` +import boto3 +from botocore.exceptions import ClientError + +def generate_presigned_url(s3_client, client_method, method_parameters, expires_in): + """ + Generate a presigned Amazon S3 URL that can be used to perform an action. + + :param s3_client: A Boto3 Amazon S3 client. + :param client_method: The name of the client method that the URL performs. + :param method_parameters: The parameters of the specified client method. + :param expires_in: The number of seconds the presigned URL is valid for. + :return: The presigned URL. + """ + try: + url = s3_client.generate_presigned_url( + ClientMethod=client_method, Params=method_parameters, ExpiresIn=expires_in + ) + except ClientError: + raise + return url + + +s3_client = boto3.client("s3") +input_url = generate_presigned_url( + s3_client, "get_object", {"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"}, 3600 +) +output_url = generate_presigned_url( + s3_client, "put_object", {"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"}, 3600 +) +print(f"{input_url=}") +print(f"{output_url=}") +``` + +This script should output + +``` +input_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091' +output_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091' +``` + +### Step 3: Run the batch runner using your presigned urls + +You can now run the batch runner, using the urls generated in the previous section. + +``` +python -m vllm.entrypoints.openai.run_batch \ + -i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ + -o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ + --model --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +### Step 4: View your results + +Your results are now on S3. You can view them in your terminal by running + +``` +aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl - +``` diff --git a/examples/openai_chatcompletion_client.py b/examples/openai_chat_completion_client.py similarity index 100% rename from examples/openai_chatcompletion_client.py rename to examples/openai_chat_completion_client.py diff --git a/examples/openai_embedding_client.py b/examples/openai_embedding_client.py new file mode 100644 index 0000000000000..b73360fe15a24 --- /dev/null +++ b/examples/openai_embedding_client.py @@ -0,0 +1,23 @@ +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + +responses = client.embeddings.create(input=[ + "Hello my name is", + "The best thing about vLLM is that it supports many different models" +], + model=model) + +for data in responses.data: + print(data.embedding) # list of float of len 4096 diff --git a/examples/openai_example_batch.jsonl b/examples/openai_example_batch.jsonl new file mode 100644 index 0000000000000..5aa7e185c180a --- /dev/null +++ b/examples/openai_example_batch.jsonl @@ -0,0 +1,2 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} diff --git a/examples/production_monitoring/README.md b/examples/production_monitoring/README.md index 29b611caeda23..268f2e771018f 100644 --- a/examples/production_monitoring/README.md +++ b/examples/production_monitoring/README.md @@ -29,7 +29,8 @@ python3 ../../benchmarks/benchmark_serving.py \ --model mistralai/Mistral-7B-v0.1 \ --tokenizer mistralai/Mistral-7B-v0.1 \ --endpoint /v1/completions \ - --dataset ShareGPT_V3_unfiltered_cleaned_split.json \ + --dataset-name sharegpt \ + --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \ --request-rate 3.0 ``` diff --git a/examples/production_monitoring/grafana.json b/examples/production_monitoring/grafana.json index 071f134c6e5e0..273f7f5ac42cf 100644 --- a/examples/production_monitoring/grafana.json +++ b/examples/production_monitoring/grafana.json @@ -1,4 +1,41 @@ { + "__inputs": [ + { + "name": "DS_PROMETHEUS", + "label": "prometheus", + "description": "", + "type": "datasource", + "pluginId": "prometheus", + "pluginName": "Prometheus" + } + ], + "__elements": {}, + "__requires": [ + { + "type": "grafana", + "id": "grafana", + "name": "Grafana", + "version": "10.4.2" + }, + { + "type": "panel", + "id": "heatmap", + "name": "Heatmap", + "version": "" + }, + { + "type": "datasource", + "id": "prometheus", + "name": "Prometheus", + "version": "1.0.0" + }, + { + "type": "panel", + "id": "timeseries", + "name": "Time series", + "version": "" + } + ], "annotations": { "list": [ { @@ -25,14 +62,14 @@ "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 0, - "id": 29, + "id": null, "links": [], "liveNow": false, "panels": [ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "End to end request latency measured in seconds.", "fieldConfig": { @@ -41,6 +78,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -54,6 +92,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -111,7 +150,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -127,7 +166,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -144,7 +183,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -161,7 +200,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -178,7 +217,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:e2e_request_latency_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:e2e_request_latency_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -195,7 +234,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Number of tokens processed per second", "fieldConfig": { @@ -204,6 +243,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -217,6 +257,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -273,7 +314,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -289,7 +330,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -310,7 +351,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Inter token latency in seconds.", "fieldConfig": { @@ -319,6 +360,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -332,6 +374,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -389,7 +432,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -405,7 +448,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -422,7 +465,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -439,7 +482,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -456,7 +499,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:time_per_output_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_per_output_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -473,7 +516,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Number of requests in RUNNING, WAITING, and SWAPPED state", "fieldConfig": { @@ -482,6 +525,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -495,6 +539,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -552,7 +597,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -568,7 +613,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -585,7 +630,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -606,7 +651,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "P50, P90, P95, and P99 TTFT latency in seconds.", "fieldConfig": { @@ -615,6 +660,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -628,6 +674,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -685,7 +732,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -702,7 +749,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -718,7 +765,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -735,7 +782,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -752,7 +799,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:time_to_first_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_to_first_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -769,7 +816,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Percentage of used cache blocks by vLLM.", "fieldConfig": { @@ -778,6 +825,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -791,6 +839,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -848,7 +897,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "vllm:gpu_cache_usage_perc{model_name=\"$model_name\"}", @@ -860,7 +909,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "vllm:cpu_cache_usage_perc{model_name=\"$model_name\"}", @@ -873,23 +922,303 @@ ], "title": "Cache Utilization", "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Heatmap of request prompt length", + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 24 + }, + "id": 12, + "options": { + "calculate": false, + "cellGap": 1, + "cellValues": { + "unit": "none" + }, + "color": { + "exponent": 0.5, + "fill": "dark-orange", + "min": 0, + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Spectral", + "steps": 64 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": true + }, + "rowsFrame": { + "layout": "auto", + "value": "Request count" + }, + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": true + }, + "yAxis": { + "axisLabel": "Prompt Length", + "axisPlacement": "left", + "reverse": false, + "unit": "none" + } + }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "format": "heatmap", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "{{le}}", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Request Prompt Length", + "type": "heatmap" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Heatmap of request generation length", + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, + "id": 13, + "options": { + "calculate": false, + "cellGap": 1, + "cellValues": { + "unit": "none" + }, + "color": { + "exponent": 0.5, + "fill": "dark-orange", + "min": 0, + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Spectral", + "steps": 64 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": true + }, + "rowsFrame": { + "layout": "auto", + "value": "Request count" + }, + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": true + }, + "yAxis": { + "axisLabel": "Generation Length", + "axisPlacement": "left", + "reverse": false, + "unit": "none" + } + }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "format": "heatmap", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "{{le}}", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Request Generation Length", + "type": "heatmap" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 32 + }, + "id": 11, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(finished_reason) (increase(vllm:request_success_total{model_name=\"$model_name\"}[$__rate_interval]))", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "interval": "", + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Finish Reason", + "type": "timeseries" } ], "refresh": "", - "schemaVersion": 37, - "style": "dark", + "schemaVersion": 39, "tags": [], "templating": { "list": [ { - "current": { - "selected": false, - "text": "vllm", - "value": "vllm" - }, + "current": {}, "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "definition": "label_values(model_name)", "hide": 0, @@ -918,6 +1247,6 @@ "timezone": "", "title": "vLLM", "uid": "b281712d-8bff-41ef-9f3f-71ad43c05e9b", - "version": 2, + "version": 1, "weekStart": "" } diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py new file mode 100644 index 0000000000000..c595d98ba2750 --- /dev/null +++ b/examples/save_sharded_state.py @@ -0,0 +1,75 @@ +""" +Saves each worker's model state dict directly to a checkpoint, which enables a +fast load path for large tensor-parallel models where each worker only needs to +read its own shard rather than the entire checkpoint. + +Example usage: + +python save_sharded_state.py \ + --model /path/to/load \ + --quantization deepspeedfp \ + --tensor-parallel-size 8 \ + --output /path/to/save + +Then, the model can be loaded with + +llm = LLM( + model="/path/to/save", + load_format="sharded_state", + quantization="deepspeedfp", + tensor_parallel_size=8, +) +""" +import argparse +import dataclasses +import os +import shutil +from pathlib import Path + +from vllm import LLM, EngineArgs + +parser = argparse.ArgumentParser() +EngineArgs.add_cli_args(parser) +parser.add_argument("--output", + "-o", + required=True, + type=str, + help="path to output checkpoint") +parser.add_argument("--file-pattern", + type=str, + help="string pattern of saved filenames") +parser.add_argument("--max-file-size", + type=str, + default=5 * 1024**3, + help="max size (in bytes) of each safetensors file") + + +def main(args): + engine_args = EngineArgs.from_cli_args(args) + if engine_args.enable_lora: + raise ValueError("Saving with enable_lora=True is not supported!") + model_path = engine_args.model + if not Path(model_path).is_dir(): + raise ValueError("model path must be a local directory") + # Create LLM instance from arguments + llm = LLM(**dataclasses.asdict(engine_args)) + # Prepare output directory + Path(args.output).mkdir(exist_ok=True) + # Dump worker states to output directory + model_executor = llm.llm_engine.model_executor + model_executor.save_sharded_state(path=args.output, + pattern=args.file_pattern, + max_size=args.max_file_size) + # Copy metadata files to output directory + for file in os.listdir(model_path): + if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): + if os.path.isdir(os.path.join(model_path, file)): + shutil.copytree(os.path.join(model_path, file), + os.path.join(args.output, file)) + else: + shutil.copy(os.path.join(model_path, file), args.output) + + +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py new file mode 100644 index 0000000000000..8b74ae1d75a1d --- /dev/null +++ b/examples/tensorize_vllm_model.py @@ -0,0 +1,244 @@ +import argparse +import dataclasses +import json +import os +import uuid +from functools import partial + +from tensorizer import stream_io + +from vllm import LLM +from vllm.distributed import (init_distributed_environment, + initialize_model_parallel) +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs, + TensorizerConfig, + serialize_vllm_model) + +# yapf conflicts with isort for this docstring +# yapf: disable +""" +tensorize_vllm_model.py is a script that can be used to serialize and +deserialize vLLM models. These models can be loaded using tensorizer +to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint, +or locally. Tensor encryption and decryption is also supported, although +libsodium must be installed to use it. Install vllm with tensorizer support +using `pip install vllm[tensorizer]`. To learn more about tensorizer, visit +https://github.com/coreweave/tensorizer + +To serialize a model, install vLLM from source, then run something +like this from the root level of this repository: + +python -m examples.tensorize_vllm_model \ + --model facebook/opt-125m \ + serialize \ + --serialized-directory s3://my-bucket \ + --suffix v1 + +Which downloads the model from HuggingFace, loads it into vLLM, serializes it, +and saves it to your S3 bucket. A local directory can also be used. This +assumes your S3 credentials are specified as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and +`S3_ENDPOINT_URL`. To provide S3 credentials directly, you can provide +`--s3-access-key-id` and `--s3-secret-access-key`, as well as `--s3-endpoint` +as CLI args to this script. + +You can also encrypt the model weights with a randomly-generated key by +providing a `--keyfile` argument. + +To deserialize a model, you can run something like this from the root +level of this repository: + +python -m examples.tensorize_vllm_model \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + deserialize \ + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors + +Which downloads the model tensors from your S3 bucket and deserializes them. + +You can also provide a `--keyfile` argument to decrypt the model weights if +they were serialized with encryption. + +For more information on the available arguments for serializing, run +`python -m examples.tensorize_vllm_model serialize --help`. + +Or for deserializing: + +`python -m examples.tensorize_vllm_model deserialize --help`. + +Once a model is serialized, tensorizer can be invoked with the `LLM` class +directly to load models: + + llm = LLM(model="facebook/opt-125m", + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri = path_to_tensors, + num_readers=3, + ) + ) + +A serialized model can be used during model loading for the vLLM OpenAI +inference server. `model_loader_extra_config` is exposed as the CLI arg +`--model-loader-extra-config`, and accepts a JSON string literal of the +TensorizerConfig arguments desired. + +In order to see all of the available arguments usable to configure +loading with tensorizer that are given to `TensorizerConfig`, run: + +`python -m examples.tensorize_vllm_model deserialize --help` + +under the `tensorizer options` section. These can also be used for +deserialization in this example script, although `--tensorizer-uri` and +`--path-to-tensors` are functionally the same in this case. +""" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="An example script that can be used to serialize and " + "deserialize vLLM models. These models " + "can be loaded using tensorizer directly to the GPU " + "extremely quickly. Tensor encryption and decryption is " + "also supported, although libsodium must be installed to " + "use it.") + parser = EngineArgs.add_cli_args(parser) + subparsers = parser.add_subparsers(dest='command') + + serialize_parser = subparsers.add_parser( + 'serialize', help="Serialize a model to `--serialized-directory`") + + serialize_parser.add_argument( + "--suffix", + type=str, + required=False, + help=( + "The suffix to append to the serialized model directory, which is " + "used to construct the location of the serialized model tensors, " + "e.g. if `--serialized-directory` is `s3://my-bucket/` and " + "`--suffix` is `v1`, the serialized model tensors will be " + "saved to " + "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " + "If none is provided, a random UUID will be used.")) + serialize_parser.add_argument( + "--serialized-directory", + type=str, + required=True, + help="The directory to serialize the model to. " + "This can be a local directory or S3 URI. The path to where the " + "tensors are saved is a combination of the supplied `dir` and model " + "reference ID. For instance, if `dir` is the serialized directory, " + "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will " + "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, " + "where `suffix` is given by `--suffix` or a random UUID if not " + "provided.") + + serialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path")) + + deserialize_parser = subparsers.add_parser( + 'deserialize', + help=("Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used.")) + + deserialize_parser.add_argument( + "--path-to-tensors", + type=str, + required=True, + help="The local path or S3 URI to the model tensors to deserialize. ") + + deserialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption")) + + TensorizerArgs.add_cli_args(deserialize_parser) + + return parser.parse_args() + + + +def deserialize(): + llm = LLM(model=args.model, + load_format="tensorizer", + model_loader_extra_config=tensorizer_config + ) + return llm + + + +args = parse_args() + +s3_access_key_id = (getattr(args, 's3_access_key_id', None) + or os.environ.get("S3_ACCESS_KEY_ID", None)) +s3_secret_access_key = (getattr(args, 's3_secret_access_key', None) + or os.environ.get("S3_SECRET_ACCESS_KEY", None)) +s3_endpoint = (getattr(args, 's3_endpoint', None) + or os.environ.get("S3_ENDPOINT_URL", None)) + +credentials = { + "s3_access_key_id": s3_access_key_id, + "s3_secret_access_key": s3_secret_access_key, + "s3_endpoint": s3_endpoint +} + +_read_stream, _write_stream = (partial( + stream_io.open_stream, + mode=mode, + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_endpoint=s3_endpoint, +) for mode in ("rb", "wb+")) + +model_ref = args.model + +model_name = model_ref.split("/")[1] + +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = "8080" + +init_distributed_environment(world_size=1, rank=0, local_rank=0) +initialize_model_parallel() + +keyfile = args.keyfile if args.keyfile else None + + +if args.model_loader_extra_config: + config = json.loads(args.model_loader_extra_config) + tensorizer_args = TensorizerConfig(**config)._construct_tensorizer_args() + tensorizer_args.tensorizer_uri = args.path_to_tensors +else: + tensorizer_args = None + +if args.command == "serialize": + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + + input_dir = args.serialized_directory.rstrip('/') + suffix = args.suffix if args.suffix else uuid.uuid4().hex + base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" + model_path = f"{base_path}/model.tensors" + tensorizer_config = TensorizerConfig( + tensorizer_uri=model_path, + **credentials) + serialize_vllm_model(engine, tensorizer_config, keyfile) +elif args.command == "deserialize": + if not tensorizer_args: + tensorizer_config = TensorizerConfig( + tensorizer_uri=args.path_to_tensors, + encryption_keyfile = keyfile, + **credentials + ) + deserialize() +else: + raise ValueError("Either serialize or deserialize must be specified.") diff --git a/format.sh b/format.sh index deb57b2b049d1..d110855f8c273 100755 --- a/format.sh +++ b/format.sh @@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}') MYPY_VERSION=$(mypy --version | awk '{print $2}') CODESPELL_VERSION=$(codespell --version) ISORT_VERSION=$(isort --vn) +CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') # # params: tool name, tool version, required version tool_version_check() { @@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)" tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)" tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)" YAPF_FLAGS=( '--recursive' @@ -93,12 +95,29 @@ fi echo 'vLLM yapf: Done' # Run mypy -# TODO(zhuohan): Enable mypy -# echo 'vLLM mypy:' -# mypy - +echo 'vLLM mypy:' +mypy vllm/attention --config-file pyproject.toml +mypy vllm/core --config-file pyproject.toml +mypy vllm/distributed --config-file pyproject.toml +mypy vllm/entrypoints --config-file pyproject.toml +mypy vllm/executor --config-file pyproject.toml +mypy vllm/usage --config-file pyproject.toml +mypy vllm/*.py --config-file pyproject.toml +mypy vllm/transformers_utils --config-file pyproject.toml +mypy vllm/engine --config-file pyproject.toml +mypy vllm/worker --config-file pyproject.toml +mypy vllm/spec_decode --config-file pyproject.toml +mypy vllm/model_executor --config-file pyproject.toml +mypy vllm/lora --config-file pyproject.toml +mypy vllm/logging --config-file pyproject.toml +mypy vllm/model_executor --config-file pyproject.toml + + +# If git diff returns a file that is in the skip list, the file may be checked anyway: +# https://github.com/codespell-project/codespell/issues/1915 +# Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem CODESPELL_EXCLUDES=( - '--skip' '*docs/source/_build/**' + '--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,tests/lora/data/**,build/**' ) # check spelling of specified files @@ -119,10 +138,9 @@ spell_check_changed() { # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that # exist on both branches. MERGEBASE="$(git merge-base origin/main HEAD)" - if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ - codespell "${CODESPELL_EXCLUDES[@]}" + codespell "${CODESPELL_EXCLUDES[@]}" fi } @@ -166,7 +184,6 @@ lint_changed() { } # Run Ruff -echo 'vLLM ruff:' ### This flag lints individual files. --files *must* be the first command line ### arg to use this option. if [[ "$1" == '--files' ]]; then @@ -179,6 +196,7 @@ else # Format only the files that changed in last commit. lint_changed fi +echo 'vLLM ruff: Done' # check spelling of specified files isort_check() { @@ -220,6 +238,59 @@ else fi echo 'vLLM isort: Done' +# Clang-format section +# Exclude some files for formatting because they are vendored +# NOTE: Keep up to date with .github/workflows/clang-format.yml +CLANG_FORMAT_EXCLUDES=( + 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' + 'csrc/punica/bgmv/bgmv_config.h' + 'csrc/punica/bgmv/bgmv_impl.cuh' + 'csrc/punica/bgmv/vec_dtypes.cuh' + 'csrc/punica/punica_ops.cu' + 'csrc/punica/type_convert.h' +) + +# Format specified files with clang-format +clang_format() { + clang-format -i "$@" +} + +# Format files that differ from main branch with clang-format. +clang_format_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause clang-format to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + # Get the list of changed files, excluding the specified ones + changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}")) + if [ -n "$changed_files" ]; then + echo "$changed_files" | xargs -P 5 clang-format -i + fi +} + +# Format all files with clang-format +clang_format_all() { + find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ + | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \ + | xargs clang-format -i +} + +# Run clang-format +if [[ "$1" == '--files' ]]; then + clang_format "${@:2}" +elif [[ "$1" == '--all' ]]; then + clang_format_all +else + clang_format_changed +fi +echo 'vLLM clang-format: Done' + + if ! git diff --quiet &>/dev/null; then echo 'Reformatted files. Please review and stage the changes.' echo 'Changes not staged for commit:' @@ -228,5 +299,3 @@ if ! git diff --quiet &>/dev/null; then exit 1 fi - - diff --git a/gradlib_fp8/csrc/hipbsolgemm.cu b/gradlib_fp8/csrc/hipbsolgemm.cu deleted file mode 100644 index 4d6ca59520dc0..0000000000000 --- a/gradlib_fp8/csrc/hipbsolgemm.cu +++ /dev/null @@ -1,673 +0,0 @@ -// #ifdef __gfx908__ -// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below just for gfx908 and not for others -// // below lines enable hip float to half conversion which are disabled by default in hip_fp16.h -// #undef __HIP_NO_HALF_OPERATORS__ -// #undef __HIP_NO_HALF_CONVERSIONS__ -// #endif - -#include -#include -#include -#include -#include -#include -#include -#include -// #include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include "nvToolsExt.h" - -//#include - - -// #ifdef USE_ROCM -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #endif - -// #ifdef __HIP_PLATFORM_HCC__ -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #if USE_GEMM_FLAGS_FP16_ALT_IMPL -// #ifdef ROCM_BACKWARD_PASS_GUARD -// flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; -// #endif -// #endif -// #endif - -#ifndef CHECK_HIP_ERROR -#define CHECK_HIP_ERROR(error) \ - if(error != hipSuccess) \ - { \ - fprintf(stderr, \ - "Hip error: '%s'(%d) at %s:%d\n", \ - hipGetErrorString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ - } -#endif - -#ifndef CHECK_HIPBLAS_ERROR -#define CHECK_HIPBLAS_ERROR(error) \ - if(error != HIPBLAS_STATUS_SUCCESS) \ - { \ - fprintf(stderr, \ - "hipBLAS error: '%s'(%d) at %s:%d\n", \ - hipblasStatusToString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ - } -#endif - -namespace { - /*thread_local*/ cudaStream_t weight_stream; - // BUG: DLM has event and stream on different devices error - // In multi-GPU scenerio, do names defined in this namespace exist on all devices? - // C++ keyword: thread_local <- maybe this can help? - /*thread_local*/ cudaEvent_t event; - - // hipBLASLt - hipblasLtHandle_t hipblaslt_handle; - hipblasLtMatmulPreference_t preference; - size_t workspace_size = 2*128*1024*1024; - //uint64_t workspace_size = 0; - void* d_workspace; - int request_solutions = 1; - int returnedAlgoCount = 0; - - struct MatMulConfig { - hipblasOperation_t op_A; - hipblasOperation_t op_B; - int M; - int N; - int K; - hipDataType dtype; - - friend auto operator<(const MatMulConfig& left, const MatMulConfig& right) -> bool { - return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < std::tie(right.op_A, right.op_B, right.M, right.N, right.K, right.dtype); - } - }; - - // std::map, std::vector> heuristic_map; - std::map heuristic_map; - - hipEvent_t start, stop; - int bench_iters { 1 }; - int warmup_iters { 1 }; - - bool cout_print = false; - - torch::Tensor dTensor; - - //std::vector heuristicResult; -} - -//find all hipblaslt solutions for given gemm problem -std::vector hipblasLtMatmul_findallsols_wrapper( - hipblasLtHandle_t handle, - hipblasOperation_t op_A, - hipblasOperation_t op_B, - int m, int n, int k, - const void *alpha, - const void *a, - int lda, - const void *b, - int ldb, - const void *beta, - void *c, - int ldc, - hipDataType intype, - hipDataType outtype, - hipStream_t &stream) -{ - int flag { 0 }; - hipblasLtMatrixLayout_t matA, matB, matC; - hipblasLtMatmulDesc_t matmul; - if (op_A == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); - } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); - } - if (op_B == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); - } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); - } - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); - - //std::vector heuristicResult(10); - //CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( - // handle, matmul, matA, matB, matC, matC, - // preference, 10, heuristicResult.data(), &returnedAlgoCount)); - std::vector heuristicResult; - CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos(handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, - op_A, - op_B, - intype, - intype, - outtype, - outtype, - HIPBLAS_COMPUTE_32F, - heuristicResult)); - - std::vector algoIndex; - int returned_algo_count = heuristicResult.size(); - //for (int i = 0; i < returnedAlgoCount; i++) { - for (int i = 0; i < returned_algo_count; i++) { - auto algo = heuristicResult[i].algo; - size_t ret_workspace_size = 0; - auto status = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, - alpha, - matA, - matB, - beta, - matC, - matC, - algo, - ret_workspace_size - ); - if (status == HIPBLAS_STATUS_SUCCESS) { - if (ret_workspace_size heuristicResult(1); - if (solution_index<0) { - //nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); - std::cout << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" << std::endl; - if (cout_print) { - std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") - << " (" << m << ", " << n << ", " << k << "), dtype: " << intype - << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl; - } - //std::vector heuristicResult(request_solutions); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( - handle, matmul, matA, matB, matC, matC, - preference, request_solutions, heuristicResult.data(), &returnedAlgoCount)); - if((returnedAlgoCount != request_solutions) && cout_print) { - std::cout << "less solution found! request: " << request_solutions - << ", found: " << returnedAlgoCount << std::endl; - } - //heuristic_map[gemm_key] = heuristicResult[0]; -/* - if (returnedAlgoCount == 1) { - heuristic_map[gemm_key] = heuristicResult[0]; - } else { - // benchmark requested solutions and pick best one - int bestIndex { -1 }; - double bestMs { std::numeric_limits::max() }; - for (int sol { 0 }; sol < returnedAlgoCount; ++sol) { - // warm up - for (int iter { 0 }; iter < warmup_iters; ++iter) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - // performance measuring - double eventMs; - CHECK_HIP_ERROR(hipEventRecord(start, stream)); - for (int iter { 0 }; iter < bench_iters; ++iter) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - CHECK_HIP_ERROR(hipEventRecord(stop, stream)); - CHECK_HIP_ERROR(hipEventSynchronize(stop)); - float temp; - CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); - eventMs = double(temp); - eventMs /= bench_iters; - - if (cout_print) { - std::cout << " Sol " << sol << ": average time per iter " << std::to_string(eventMs) << " ms"; - } - if (bestMs > eventMs) { - bestMs = eventMs; - bestIndex = sol; - if (cout_print) { - std::cout << " *" << std::endl; - } - } else { - if (cout_print) { - std::cout << std::endl; - } - } - } - heuristic_map[gemm_key] = heuristicResult[bestIndex]; - } -*/ - //nvtxRangePop(); - } else { - std::vector algoIndex(1); - algoIndex[0]=solution_index; - //std::vector tmpAlgo; - CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResult)); - } - - //size_t ret_workspace_size = 0; - - //auto status1 = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, - // alpha, - // matA, - // matB, - // beta, - // matC, - // matC, - // heuristicResult[0].algo, - // ret_workspace_size - //); - //if (status1 == HIPBLAS_STATUS_SUCCESS) { - // std::cout << "Workspace size" << ret_workspace_size << std::endl; - - //} else { - // std::cout << "Algo not supported!!!" << std::endl; - - //} - hipblasStatus_t status = hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, - &heuristicResult[0].algo, - d_workspace, workspace_size, - stream); - - //nvtxRangePushA("hipBLASLt variables deletion"); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); - //nvtxRangePop(); - - return status; -} -///////////////////////////////////////////////////////////////////////////////////////////////////////// -torch::Tensor HipbSolIdxBlas( - const torch::Tensor& mat1, - const torch::Tensor& mat2, - const int solution_index, - at::optional Type = at::nullopt, - at::optional scale1 = at::nullopt, - at::optional scale2 = at::nullopt, - at::optional scaleOut = at::nullopt - ) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; - // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | mat2 info: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; - - TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - auto inType { mat1.options().dtype() }; - auto outType = inType.toScalarType(); - if (Type.has_value()) outType = torch::python::detail::py_object_to_dtype(Type.value()); - auto options { at::TensorOptions().dtype(outType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; - // std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << std::endl; - - bool transpose_result = true; - bool transpose_mat1; - bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { - transpose_mat2 = false; - } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { - transpose_mat2 = true; - } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); - } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { - transpose_mat1 = false; - } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { - transpose_mat1 = true; - } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); - } - - if (transpose_result) { - bool tmp = transpose_mat1; - transpose_mat1 = !transpose_mat2; - transpose_mat2 = !tmp; - mat1_strides = mat2.strides(); - mat2_strides = mat1.strides(); - mat1_sizes = mat2.sizes(); - mat2_sizes = mat1.sizes(); - } - // std::cout << " | transpose_result: " << (transpose_result ? "true" : "false") << std::endl - // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << std::endl - // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << std::endl; - // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | B matrix: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; - - float one { 1.0f }; - float zero { 0.0f }; - int64_t m = mat1_sizes[transpose_result ? 1 : 0]; - int64_t k = mat1_sizes[transpose_result ? 0 : 1]; - int64_t n = mat2_sizes[transpose_result ? 0 : 1]; - int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; - int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; - int64_t result_ld = result.stride(transpose_result ? 0 : 1); - - void * d_scale1 = nullptr, * d_scale2 = nullptr, * d_scaleOut = nullptr; - if (scale1.has_value()) { - d_scale1 = static_cast(scale1.value().data_ptr()); - } - if (scale2.has_value()) { - d_scale2 = static_cast(scale2.value().data_ptr()); - } - if (scaleOut.has_value()) { - d_scaleOut = static_cast(scaleOut.value().data_ptr()); - } - - - hipDataType hipblasInType, hipblasOutType; - if (inType == at::kHalf) { - hipblasInType = HIP_R_16F; - } else if (inType == at::kBFloat16) { - hipblasInType = HIP_R_16BF; - } else if (inType == at::kFloat) { - hipblasInType = HIP_R_32F; - } else if (inType == at::kFloat8_e4m3fnuz) { - hipblasInType = HIP_R_8F_E4M3_FNUZ; - } else { - assert(false && "Wrong datatype!"); - } - - if (outType == at::kHalf) { - hipblasOutType = HIP_R_16F; - } else if (outType == at::kBFloat16) { - hipblasOutType = HIP_R_16BF; - } else if (outType == at::kFloat) { - hipblasOutType = HIP_R_32F; - } else if (outType == at::kFloat8_e4m3fnuz) { - hipblasOutType = HIP_R_8F_E4M3_FNUZ; - } else { - assert(false && "Wrong datatype!"); - } - void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; - void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; - void *ptrC { static_cast(result.data_ptr()) }; - if (transpose_result) std::swap(d_scale1, d_scale2); - auto current_stream { torch::hip::getCurrentHIPStream().stream() }; - - CHECK_HIPBLAS_ERROR(hipblasLtMatmul_sol_wrapper( - hipblaslt_handle, - transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - &one, - ptrA, mat1_ld, d_scale1, - ptrB, mat2_ld, d_scale2, - &zero, - ptrC, result_ld, d_scaleOut, - hipblasInType, - hipblasOutType, - current_stream,solution_index)); - - return result; -} - -//find all hipblas solutions and return them to python land -std::vector HipbFindAllSolIdxBlas( - const torch::Tensor& mat1, - const torch::Tensor& mat2, - at::optional Type = at::nullopt - ) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; - TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - auto inType { mat1.options().dtype() }; - auto outType = inType.toScalarType(); - if (Type.has_value()) outType = torch::python::detail::py_object_to_dtype(Type.value()); - auto options { at::TensorOptions().dtype(outType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; - bool transpose_result = true; - bool transpose_mat1; - bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { - transpose_mat2 = false; - } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { - transpose_mat2 = true; - } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); - } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { - transpose_mat1 = false; - } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { - transpose_mat1 = true; - } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); - } - if (transpose_result) { - bool tmp = transpose_mat1; - transpose_mat1 = !transpose_mat2; - transpose_mat2 = !tmp; - mat1_strides = mat2.strides(); - mat2_strides = mat1.strides(); - mat1_sizes = mat2.sizes(); - mat2_sizes = mat1.sizes(); - } - float one { 1.0f }; - float zero { 0.0f }; - int64_t m = mat1_sizes[transpose_result ? 1 : 0]; - int64_t k = mat1_sizes[transpose_result ? 0 : 1]; - int64_t n = mat2_sizes[transpose_result ? 0 : 1]; - int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; - int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; - int64_t result_ld = result.stride(transpose_result ? 0 : 1); - hipDataType hipblasInType, hipblasOutType; - if (inType == at::kHalf) { - hipblasInType = HIP_R_16F; - } else if (inType == at::kBFloat16) { - hipblasInType = HIP_R_16BF; - } else if (inType == at::kFloat) { - hipblasInType = HIP_R_32F; - } else if (inType == at::kFloat8_e4m3fnuz) { - hipblasInType = HIP_R_8F_E4M3_FNUZ; - } else { - assert(false && "Wrong datatype!"); - } - if (outType == at::kHalf) { - hipblasOutType = HIP_R_16F; - } else if (outType == at::kBFloat16) { - hipblasOutType = HIP_R_16BF; - } else if (outType == at::kFloat) { - hipblasOutType = HIP_R_32F; - } else if (outType == at::kFloat8_e4m3fnuz) { - hipblasOutType = HIP_R_8F_E4M3_FNUZ; - } else { - assert(false && "Wrong datatype!"); - } - void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; - void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; - void *ptrC { static_cast(result.data_ptr()) }; - auto current_stream { torch::hip::getCurrentHIPStream().stream() }; - - return hipblasLtMatmul_findallsols_wrapper( - hipblaslt_handle, - transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - &one, - ptrA, mat1_ld, - ptrB, mat2_ld, - &zero, - ptrC, result_ld, - hipblasInType, - hipblasOutType, - current_stream); - -} -///////////////////////////////////////////////////////////////////////////////////////////////////////// - -void hipb_create_extension() -{ - //CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); - //CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); - - // hipBLASLt - CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle)); - CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute( - preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); - - //CHECK_HIP_ERROR(hipEventCreate(&start)); - //CHECK_HIP_ERROR(hipEventCreate(&stop)); -} - -///////////////////////////////////////////////////////////////////////////////////////////////////////// - -void hipb_destroy_extension() -{ - //CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); - //CHECK_HIP_ERROR(hipEventDestroy(event)); - - // hipBLASLt - CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); - CHECK_HIP_ERROR(hipFree(d_workspace)); - - //CHECK_HIP_ERROR(hipEventDestroy(start)); - //CHECK_HIP_ERROR(hipEventDestroy(stop)); -} - -///////////////////////////////////////////////////////////////////////////////////////////////////////// - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); - m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); - m.def("hipb_mm", &HipbSolIdxBlas, "mm", py::arg("mat1"), py::arg("mat2"), py::arg("solution_index"), py::arg("outType")=at::nullopt, py::arg("scale1")=at::nullopt, py::arg("scale2")=at::nullopt, py::arg("scaleOut")=at::nullopt); - m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols", py::arg("mat1"), py::arg("mat2"), py::arg("outType")=at::nullopt); -} diff --git a/gradlib_fp8/gemm_runner.py b/gradlib_fp8/gemm_runner.py deleted file mode 100644 index 683faf7c936e0..0000000000000 --- a/gradlib_fp8/gemm_runner.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import hipbsolidxgemm -import numpy as np -import torch.nn.functional as F -import sys -import pandas as pd -import timeit - -hipbsolidxgemm.hipb_create_extension() - -class TunedGemm: - def __init__(self,tuned_csv_file): - self.bestsols = pd.read_csv(tuned_csv_file,index_col=[0]) - self.create_ds() - def create_ds(self): - df = self.bestsols - solds = {} - for i in range(len(df)): - ds = df.iloc[i] - key = (ds['M'],ds['N'],ds['K']) - solds[key] = (int(ds['solidx'])) - #print(solds) - self.solids = solds - def query_sol(self,m,n,k): - return self.solids.get((m,n,k),(0,0)) - def mm(self,inp,weights): - soltype,solidx = self.query_sol(m=weights.shape[0],n=inp.shape[0],k=inp.shape[1]) - if soltype==1: - out = hipbsolidxgemm.hipb_mm(inp,weights.t(),solidx) - else: - out = F.linear(inp,weights) - return out - def run_all_tuned_sols(self): - for i in range(len(self.bestsols)): - ds = self.bestsols.iloc[i] - print('>>> Running tuned solution') - print(ds) - inp = torch.randn((ds['N'], ds['K']), dtype=get_dtype(ds['dtype']), device='cuda') - weights = torch.randn((ds['M'], ds['K']), dtype=get_dtype(ds['dtype']), device='cuda') - self.mm(inp,weights) - -def get_dtype(dtype_csv): - if dtype_csv=='torch.float16': - dtype = torch.float16 - elif dtype_csv=='torch.bfloat16': - dtype = torch.bfloat16 - elif dtype_csv=='torch.float32': - dtype = torch.float32 - elif dtype_csv=='torch.float32': - dtype = torch.float32 - return dtype - -if __name__ == '__main__': - tgemm = TunedGemm(sys.argv[1]) #csv file with tuned sols goes in argv[1] - print(tgemm.bestsols) - tgemm.run_all_tuned_sols() - - diff --git a/gradlib_fp8/gemm_tuner.py b/gradlib_fp8/gemm_tuner.py deleted file mode 100644 index 7cf0614aeb840..0000000000000 --- a/gradlib_fp8/gemm_tuner.py +++ /dev/null @@ -1,93 +0,0 @@ -import torch -import os -import argparse -from gradlib.GemmTuner import GemmTuner -import hipbsolidxgemm -import numpy as np -import torch.nn.functional as F -import sys -import pandas as pd -import json -import random -from pathlib import Path -hipbsolidxgemm.hipb_create_extension() - -''' -{'architectures': ['LlamaForCausalLM'], 'bos_token_id': 1, 'eos_token_id': 2, 'hidden_act': 'silu', 'hidden_size': 5120, 'initializer_range': 0.02, -'intermediate_size': 13824, 'max_position_embeddings': 2048, 'model_type': 'llama', 'num_attention_heads': 40, 'num_hidden_layers': 40, 'num_key_value_heads': 40, -'pretraining_tp': 1, 'rms_norm_eps': 1e-05, 'rope_scaling': None, 'tie_word_embeddings': False, 'torch_dtype': 'float16', 'transformers_version': '4.33.0.dev0', 'use_cache': True, 'vocab_size': 32000} -''' -def generate_mk_sets(model_dir, tp=1): - f = open(f'{model_dir}/config.json') - data = json.load(f) - hidden_size = data['hidden_size'] - intermediate_size = data['intermediate_size'] - total_num_heads = data['num_attention_heads'] - total_num_kv_heads = data['num_key_value_heads'] - head_dim = hidden_size // total_num_heads - return [((total_num_heads + (2*total_num_kv_heads)) * head_dim // tp, hidden_size), (hidden_size, hidden_size // tp), (intermediate_size *2 // tp, hidden_size), (hidden_size, intermediate_size // tp) ], hidden_size - -def get_dtype(dtype_str): - dtype = torch.float16 - if dtype_str == 'f32': - dtype = torch.float32 - elif dtype_str == 'bf16': - dtype = torch.bfloat16 - elif dtype_str == 'f16': - dtype = torch.float16 - elif dtype_str == 'f8': - dtype = torch.float8_e4m3fnuz - else: - print('>>> Warning! Invalid dtype', dtype_str, 'using default dtype f16') - return dtype - - -def list_of_ints(arg): - return list(map(int, arg.split(','))) - -def load_input_gemms(input_file): - if Path(input_file).is_file(): - return - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--model_dir", type=str, default=os.getenv('GTUNE_MODEL', ""), help="Enter the location of your model directory") - parser.add_argument("--tuned_file", type=str, default=os.getenv('GTUNE_TUNED', "tuned.csv"), help="output file for tuned gemm solutions") - parser.add_argument("--input_file", type=str, default=os.getenv('GTUNE_INPUT', None), help="list of gemms to tune for, mutually exclusive with model_dir") - parser.add_argument("--tp", type=int, default=os.getenv('GTUNE_TP', 1), help="Tensor parallelism to be used.") - parser.add_argument("--indtype", type=str, default='f8', help="dtype f32 f16 bf16") - parser.add_argument("--outdtype", type=str, default='f16', help="dtype f32 f16 bf16") - parser.add_argument("--batch_size", type=int, default=os.getenv('GTUNE_BATCH_SIZE', 1), help="Batch size to tune for") - parser.add_argument("--nsets", type=list_of_ints, default=[1, 512, 1024, 2048, 3072, 4096, 8192, 16384], help="N sizes to tune for: 1,128,2048") - args = parser.parse_args() - - indtype = get_dtype(args.indtype) - outdtype = get_dtype(args.outdtype) - - gtuner = GemmTuner(indtype, outdtype, args.tuned_file) - nsets = [i * args.batch_size for i in args.nsets] - if args.input_file: - print(f">>> Loading {args.input_file}") - if not Path(args.input_file).is_file(): - print(f">>> ERROR: {args.input_file} does not exist. Exiting") - exit(1) - shapes = pd.read_csv(args.input_file) - for i in range(len(shapes)): - ds = shapes.iloc[i] - gtuner.add_gemm(ds['M'],ds['N'],ds['K']) - else: - if not args.model_dir: - print(">>> Warning! NO MODEL SPECIFIED. Tuning for LL2 13B TP1") - #LL2 13B sizes - mksets = [(15360, 5120), (5120, 5120), (27648, 5120), (5120, 13824)] - gtuner.add_gemm(m=32000, n=1, k=5120) # logits gemm - else: - mksets, hidden_size = generate_mk_sets(args.model_dir, args.tp) - gtuner.add_gemm(m=32000//args.tp, n=1 * args.batch_size, k=hidden_size) #TODO: Handle cases where vocab_size is not divisible by tp - - for n in sorted(nsets): - for m, k in mksets: - gtuner.add_gemm(m, n, k) - - gtuner.find_best_sols() diff --git a/gradlib_fp8/gradlib/GemmTuner.py b/gradlib_fp8/gradlib/GemmTuner.py deleted file mode 100644 index 55071228fb32d..0000000000000 --- a/gradlib_fp8/gradlib/GemmTuner.py +++ /dev/null @@ -1,143 +0,0 @@ -import torch -import os -import argparse -import hipbsolidxgemm -import numpy as np -import torch.nn.functional as F -import sys -import pandas as pd -import json -import random -from pathlib import Path - -hipbsolidxgemm.hipb_create_extension() - -rtol = 1e-5 -atol = 1 -dtype = torch.float16 - -class Gemm: - def __init__(self,m,n,k,indtype,outdtype): - self.m=m - self.k=k - self.n=n - self.indtype=indtype - self.outdtype=outdtype - self.nb = 37 - self.inp = torch.randn((self.n, self.k), device='cuda').to(self.indtype) - self.weights = torch.randn((self.m, self.k), device='cuda').to(self.indtype) - #weights2 is used in measurement/warm iters to ensure HBM fetch for weight tensors - self.weights2 = torch.randn((self.nb, self.m, self.k), device='cuda').to(self.indtype) - self.blob = torch.ones(128*1024*1024, dtype=torch.float32, device='cuda') - self.topn = 20 #number of top solutions from each source - self.hipb_sols=[] - self.rtol = 1e-5 - self.atol = 1 - self.start = torch.cuda.Event(enable_timing=True) - self.end = torch.cuda.Event(enable_timing=True) - - - def find_hipblas_sols(self): - sols = hipbsolidxgemm.hipb_findallsols(self.inp,self.weights.t(),self.outdtype) - print('M N K',self.m,self.n,self.k,'>>> Total hipb solutions',len(sols), flush=True) - #print(sols) - self.hipb_sols = sols - - - def check_gemm_ref(self,libtype,solidx): - ref = F.linear(self.inp.to(torch.float32),self.weights.to(torch.float32)).to(self.outdtype) - c = hipbsolidxgemm.hipb_mm(self.inp,self.weights.t(),solidx,self.outdtype) - if torch.allclose(c, ref, atol=self.atol, rtol=self.rtol): - #print('>>>',libtype,'Solidx',solidx,'passed reference test') - return True - else: - print('>>>','Solidx',solidx,'FAILED reference test', flush=True) - print(ref, flush=True) - print(c, flush=True) - return False - def hipb_time_sol(self,solidx,cold_iters=2,warm_iters=10): - #print('>>>hipbtime',solidx) - for i in range(cold_iters): - c = hipbsolidxgemm.hipb_mm(self.inp,self.weights.t(),solidx,self.outdtype) - self.start.record() - for i in range(warm_iters): - c = hipbsolidxgemm.hipb_mm(self.inp,self.weights2 [random.randint(0,self.nb-1)].t(),solidx,self.outdtype) - self.end.record() - torch.cuda.synchronize() - gtime = self.start.elapsed_time(self.end)/warm_iters - #print('>>> Solidx GTime',solidx,gtime,'ms') - return gtime - def hipb_time_all_sols(self,fast_mode=0,top_sols=0): - coldi=20; warmi=20 - if fast_mode: coldi=2; warmi=2 - solutions = self.hipb_sols - if top_sols: solutions = self.hipb_top_sols - gtimes = {} - for solidx in solutions: - gtimes[solidx] = self.hipb_time_sol(solidx, cold_iters=coldi, warm_iters=warmi) - self.hipb_gtimedf = pd.DataFrame.from_dict(gtimes,orient='index',columns=['gtimems']).sort_values(by='gtimems') - self.hipb_gtimedf.to_csv('/tmp/hipb_gtimedf.csv') - print('>>> HipBlasLt top solutions, Fast Mode',fast_mode) - print(self.hipb_gtimedf.head(self.topn)) - def warmup(self,warmi=500): - for i in range(warmi): - self.blob = self.blob + 0.00001 - def functional_check_topn_fastest(self): - hipb_topn = [] - for solidx in self.hipb_gtimedf.index[:self.topn]: - if self.check_gemm_ref(libtype='hipblaslt',solidx=solidx): - hipb_topn.append(solidx) - self.hipb_top_sols = hipb_topn - - def find_fastest_solution(self): - self.find_hipblas_sols() - self.warmup() - self.hipb_time_all_sols(fast_mode=1) - self.functional_check_topn_fastest() - self.warmup() - self.hipb_time_all_sols(fast_mode=0,top_sols=1) - if len(self.hipb_gtimedf)>0: - best_hipb_time = self.hipb_gtimedf.gtimems.iloc[0] - self.best_solidx = self.hipb_gtimedf.index[0] - self.best_soltime = best_hipb_time - else: - print('>>> No hipblas solutions found!',flush=True) - self.best_solidx = 0 - self.best_soltime = 0 - print('>>> Fastest Solution is',self.best_solidx,self.best_soltime,flush=True) - - -class GemmTuner: - def __init__(self, indtype, outdtype, tuned_file=None): - self.gemm_problems = pd.DataFrame(columns=['M','N','K']) - self.indtype = indtype - self.outdtype = outdtype - self.tuned_file = tuned_file - if Path(tuned_file).is_file(): - self.gdf = pd.read_csv(tuned_file) - else: - self.gdf = None - - def add_gemm(self,m,n,k): - if ( self.gdf is None or (self.gdf[(self.gdf['M'] == m) & (self.gdf['N'] == n) & (self.gdf['K'] == k)].empty)): - entry = {'M':[m], 'N':[n], 'K':[k]} - df = pd.DataFrame(entry) - self.gemm_problems = pd.concat([self.gemm_problems, df],ignore_index=True) - else: - print(f">>>Info: Found Duplicate shape(M:{m}, N:{n}, K:{k}), skipping") - - def find_best_sols(self): - df = self.gemm_problems - soldf = pd.DataFrame() - for i in range(len(df)): - ds = df.iloc[i] - gemmobj = Gemm(ds['M'],ds['N'],ds['K'],indtype=self.indtype,outdtype=self.outdtype) - gemmobj.find_fastest_solution() - soldf.loc[i,'solidx'] = gemmobj.best_solidx - soldf.loc[i,'soltimems'] = gemmobj.best_soltime - soldf['indtype'] = self.indtype - soldf['outdtype'] = self.outdtype - finaldf = pd.concat([self.gemm_problems, soldf],axis=1) - finaldf = pd.concat([finaldf, self.gdf]) - finaldf.to_csv(self.tuned_file, index=False) - print(finaldf) diff --git a/gradlib_fp8/setup.py b/gradlib_fp8/setup.py deleted file mode 100644 index 933503ec5d54e..0000000000000 --- a/gradlib_fp8/setup.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -import setuptools -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension -from torch.utils.hipify import hipify_python -import os -import subprocess -import re - -this_dir = os.path.dirname(os.path.abspath(__file__)) -#gpus = subprocess.check_output("/opt/rocm/bin/rocminfo").decode('UTF-8').split('\n') -#gpus = list(set([re.search('(gfx94.)', g).group(0) for g in gpus if 'gfx94' in g])) -gpus = ['gfx90a','gfx940','gfx941','gfx942'] -#gpus = ['gfx90a','gfx940'] -extra_args = ["--offload-arch=" + g for g in gpus] - - -#sets_rocm_pytorch = False -maj_ver, min_ver, *_ = torch.__version__.split('.') -if int(maj_ver) > 1 or (int(maj_ver) == 1 and int(min_ver) >= 5): - from torch.utils.cpp_extension import ROCM_HOME - is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False - -ext_modules = [] - -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')): - generator_flag = ['-DOLD_GENERATOR'] - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) - -version_ge_1_1 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): - version_ge_1_1 = ['-DVERSION_GE_1_1'] -version_ge_1_3 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): - version_ge_1_3 = ['-DVERSION_GE_1_3'] -version_ge_1_5 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): - version_ge_1_5 = ['-DVERSION_GE_1_5'] -version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 - -include_dirs=[os.path.join(this_dir, 'csrc')] - -#if is_rocm_pytorch: -# import shutil -# with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx: -# hipify_python.hipify(project_directory=this_dir, output_directory=this_dir, includes="csrc/*", -# show_detailed=True, is_pytorch_extension=True, clean_ctx=clean_ctx) - -if not is_rocm_pytorch: - ext_modules.append( - CUDAExtension( - name='gradlib', - sources=['grad_funcs.cu'], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc':['-O3','-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', "--expt-relaxed-constexpr", "-ftemplate-depth=1024", '-gencode=arch=compute_70,code=sm_70','-gencode=arch=compute_80,code=sm_80','-gencode=arch=compute_80,code=compute_80'] - } - ) - ) -elif is_rocm_pytorch: - ext_modules.append( - CUDAExtension( - name='hipbsolidxgemm', - sources=['./csrc/hipbsolgemm.cu'], - include_dirs=include_dirs, - # add additional libraries argument for hipblaslt - libraries=['hipblaslt'], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc':['-O3','-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - "-ftemplate-depth=1024"] + extra_args - } - ) - ) - -setup( - name='gradlib', - packages=['gradlib'], - ext_modules=ext_modules, - cmdclass={ - 'build_ext': BuildExtension -}) - -# python setup.py build && cp build/lib*/gradlib* ../ diff --git a/patch_xformers.rocm.sh b/patch_xformers.rocm.sh deleted file mode 100644 index de427b24d306f..0000000000000 --- a/patch_xformers.rocm.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -set -e - -XFORMERS_VERSION="0.0.23" - -export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)') - -if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then - echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed" - exit 1 -fi - -export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)') -export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)') - -echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}" -echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}" - -if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then - echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}" - patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch" - echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}" -else - echo "${XFORMERS_FMHA_FLASH_PATH} was patched before" -fi - -if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then - echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}" - patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch" - echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}" -else - echo "${XFORMERS_FMHA_COMMON_PATH} was patched before" -fi diff --git a/pyproject.toml b/pyproject.toml index 180048dee39bd..2af36d1de39be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch == 2.2.1", + "torch == 2.3.0", "wheel", ] build-backend = "setuptools.build_meta" @@ -32,6 +32,7 @@ select = [ "SIM", # isort # "I", + "G", ] ignore = [ # star imports @@ -46,16 +47,28 @@ ignore = [ python_version = "3.8" ignore_missing_imports = true +check_untyped_defs = true +follow_imports = "skip" files = "vllm" # TODO(woosuk): Include the code from Megatron and HuggingFace. -exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/" - +exclude = [ + "vllm/model_executor/parallel_utils/|vllm/model_executor/models/", + # Ignore triton kernels in ops. + 'vllm/attention/ops/.*\.py$' +] [tool.codespell] ignore-words-list = "dout, te, indicies" -skip = "./tests/prompts,./benchmarks/sonnet.txt,./gradlib" +skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build,./gradlib" [tool.isort] use_parentheses = true skip_gitignore = true + +[tool.pytest.ini_options] +markers = [ + "skip_global_cleanup", + "llm: run tests for vLLM API only", + "openai: run tests for OpenAI API only", +] diff --git a/requirements-build.txt b/requirements-build.txt index 2bc07fb152aac..1a07a94e82e04 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -3,5 +3,5 @@ cmake>=3.21 ninja packaging setuptools>=49.4.0 -torch==2.2.1 +torch==2.3.0 wheel diff --git a/requirements-common.txt b/requirements-common.txt index ff053388a23e1..3ea22276f63f4 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -5,10 +5,17 @@ sentencepiece # Required for LLaMA tokenizer. numpy requests py-cpuinfo -transformers >= 4.39.1 # Required for StarCoder2 & Llava. +transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3. +tokenizers >= 0.19.1 # Required for Llama 3. fastapi +aiohttp +openai uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 -tiktoken == 0.6.0 # Required for DBRX tokenizer -outlines == 0.0.34 # Requires torch >= 2.1.0 \ No newline at end of file +prometheus-fastapi-instrumentator >= 7.0.0 +tiktoken >= 0.6.0 # Required for DBRX tokenizer +lm-format-enforcer == 0.10.1 +outlines == 0.0.34 # Requires torch >= 2.1.0 +typing_extensions +filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 36d20bc9473ea..b739642d8d344 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,5 +2,5 @@ -r requirements-common.txt # Dependencies for x86_64 CPUs -torch == 2.2.1+cpu -triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. +torch == 2.3.0+cpu +triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 6ee75e8139c04..5109f17356178 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -3,8 +3,7 @@ # Dependencies for NVIDIA GPUs ray >= 2.9 -pynvml == 11.5.0 -vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library -torch == 2.2.1 -xformers == 0.0.25 # Requires PyTorch 2.2.1 -triton >= 2.1.0 +nvidia-ml-py # for pynvml package +torch == 2.3.0 +xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 +vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0 diff --git a/requirements-dev.txt b/requirements-dev.txt index 75d22bbdb2a1b..cf2bb9bef22d9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,27 +5,30 @@ tomli==2.0.1 ruff==0.1.5 codespell==2.2.6 isort==5.13.2 +clang-format==18.1.5 # type checking -mypy==0.991 +mypy==1.9.0 types-PyYAML types-requests types-setuptools # testing pytest +tensorizer>=2.9.0 pytest-forked pytest-asyncio pytest-rerunfailures pytest-shard -httpx + +# testing utils +awscli einops # required for MPT -openai +httpx +peft requests ray -peft -awscli -ai2-olmo # required for OLMo +sentence-transformers # required for embedding # Benchmarking aiohttp diff --git a/requirements-rocm.txt b/requirements-rocm.txt index d43b256503afc..cc42839a975d0 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -2,4 +2,5 @@ -r requirements-common.txt # Dependencies for AMD GPUs -ray >= 2.9.3 +ray >= 2.10.0 +pytest-asyncio diff --git a/rocm_patch/commonpy_xformers-0.0.23.rocm.patch b/rocm_patch/commonpy_xformers-0.0.23.rocm.patch deleted file mode 100644 index 4d7495cf13e1d..0000000000000 --- a/rocm_patch/commonpy_xformers-0.0.23.rocm.patch +++ /dev/null @@ -1,13 +0,0 @@ ---- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000 -+++ common.py 2023-11-28 16:14:19.846233146 +0000 -@@ -298,8 +298,8 @@ - dtype = d.query.dtype - if device_type not in cls.SUPPORTED_DEVICES: - reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})") -- if device_type == "cuda" and not _built_with_cuda: -- reasons.append("xFormers wasn't build with CUDA support") -+ #if device_type == "cuda" and not _built_with_cuda: -+ # reasons.append("xFormers wasn't build with CUDA support") - if device_type == "cuda": - device_capability = torch.cuda.get_device_capability(d.device) - if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY: diff --git a/rocm_patch/flashpy_xformers-0.0.23.rocm.patch b/rocm_patch/flashpy_xformers-0.0.23.rocm.patch deleted file mode 100644 index ac846728a7a91..0000000000000 --- a/rocm_patch/flashpy_xformers-0.0.23.rocm.patch +++ /dev/null @@ -1,152 +0,0 @@ ---- flash_ori.py 2023-12-13 05:43:31.530752623 +0000 -+++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000 -@@ -36,44 +36,44 @@ - - FLASH_VERSION = "0.0.0" - try: -- try: -- from ... import _C_flashattention # type: ignore[attr-defined] -- from ..._cpp_lib import _build_metadata -- -- if _build_metadata is not None: -- FLASH_VERSION = _build_metadata.flash_version -- except ImportError: -- import flash_attn -- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention -- -- FLASH_VERSION = flash_attn.__version__ -- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3]) -- if ( -- flash_ver_parsed != (2, 3, 6) -- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1" -- ): -- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api") -+ #try: -+ # from ... import _C_flashattention # type: ignore[attr-defined] -+ # from ..._cpp_lib import _build_metadata -+ -+ # if _build_metadata is not None: -+ # FLASH_VERSION = _build_metadata.flash_version -+ #except ImportError: -+ import flash_attn -+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention -+ -+ FLASH_VERSION = flash_attn.__version__ -+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3]) -+ # if ( -+ # flash_ver_parsed != (2, 3, 6) -+ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1" -+ # ): -+ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api") - - # create library so that flash-attn goes through the PyTorch Dispatcher -- _flash_lib = torch.library.Library("xformers_flash", "DEF") -- -- _flash_lib.define( -- "flash_fwd(Tensor query, Tensor key, Tensor value, " -- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, " -- "int max_seqlen_q, int max_seqlen_k, " -- "float p, float softmax_scale, " -- "bool is_causal, int window_left, " -- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)" -- ) -+ #_flash_lib = torch.library.Library("xformers_flash", "DEF") - -- _flash_lib.define( -- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " -- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " -- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " -- "int max_seqlen_q, int max_seqlen_k, " -- "float p, float softmax_scale, bool is_causal, " -- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)" -- ) -+ #_flash_lib.define( -+ # "flash_fwd(Tensor query, Tensor key, Tensor value, " -+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, " -+ # "int max_seqlen_q, int max_seqlen_k, " -+ # "float p, float softmax_scale, " -+ # "bool is_causal, int window_left, " -+ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)" -+ #) -+ -+ #_flash_lib.define( -+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " -+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " -+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " -+ # "int max_seqlen_q, int max_seqlen_k, " -+ # "float p, float softmax_scale, bool is_causal, " -+ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)" -+ #) - - def _flash_fwd( - query, -@@ -111,8 +111,8 @@ - p, - softmax_scale, - is_causal, -- window_left, # window_size_left -- window_right, # window_size_right -+ # window_left, # window_size_left -+ # window_right, # window_size_right - return_softmax, - None, # rng - ) -@@ -134,15 +134,15 @@ - out, - cu_seq_lens_q, - cu_seq_lens_k, -- seqused_k, -+ # seqused_k, - max_seq_len_q, - max_seq_len_k, - p, - softmax_scale, - False, - is_causal, -- window_left, -- window_right, -+ # window_left, -+ # window_right, - return_softmax, - None, - ) -@@ -184,8 +184,8 @@ - p, - softmax_scale, - is_causal, -- window_left, -- window_right, -+ # window_left, -+ # window_right, - None, - rng_state, - ) -@@ -208,15 +208,15 @@ - softmax_scale, - False, # zero_tensors - is_causal, -- window_left, -- window_right, -+ # window_left, -+ # window_right, - None, - rng_state, - ) - return dq, dk, dv - -- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") -- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") -+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") -+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") - except ImportError: - pass - -@@ -400,7 +400,7 @@ - implementation. - """ - -- OPERATOR = get_operator("xformers_flash", "flash_fwd") -+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd") - SUPPORTED_DEVICES: Set[str] = {"cuda"} - CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) - SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} diff --git a/setup.py b/setup.py index c0c6d1821fd0f..6cd6095adfdd6 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import importlib.util import io import logging import os @@ -5,7 +6,7 @@ import subprocess import sys from shutil import which -from typing import List +from typing import Dict, List import torch from packaging.version import Version, parse @@ -13,10 +14,23 @@ from setuptools.command.build_ext import build_ext from torch.utils.cpp_extension import CUDA_HOME + +def load_module_from_path(module_name, path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + ROOT_DIR = os.path.dirname(__file__) logger = logging.getLogger(__name__) -# Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] -VLLM_TARGET_DEVICE = os.getenv("VLLM_TARGET_DEVICE", "cuda") + +# cannot import envs directly because it depends on vllm, +# which is not installed yet +envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py')) + +VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE # vLLM only supports Linux platform assert sys.platform.startswith( @@ -52,7 +66,7 @@ def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: class cmake_build_ext(build_ext): # A dict of extension directories that have been configured. - did_config = {} + did_config: Dict[str, bool] = {} # # Determine number of compilation jobs and optionally nvcc compile threads. @@ -60,10 +74,10 @@ class cmake_build_ext(build_ext): def compute_num_jobs(self): # `num_jobs` is either the value of the MAX_JOBS environment variable # (if defined) or the number of CPUs available. - num_jobs = os.environ.get("MAX_JOBS", None) + num_jobs = envs.MAX_JOBS if num_jobs is not None: num_jobs = int(num_jobs) - logger.info(f"Using MAX_JOBS={num_jobs} as the number of jobs.") + logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) else: try: # os.sched_getaffinity() isn't universally available, so fall @@ -78,11 +92,12 @@ def compute_num_jobs(self): # environment variable (if defined) or 1. # when it is set, we reduce `num_jobs` to avoid # overloading the system. - nvcc_threads = os.getenv("NVCC_THREADS", None) + nvcc_threads = envs.NVCC_THREADS if nvcc_threads is not None: nvcc_threads = int(nvcc_threads) - logger.info(f"Using NVCC_THREADS={nvcc_threads} as the number" - " of nvcc threads.") + logger.info( + "Using NVCC_THREADS=%d as the number of nvcc threads.", + nvcc_threads) else: nvcc_threads = 1 num_jobs = max(1, num_jobs // nvcc_threads) @@ -103,7 +118,7 @@ def configure(self, ext: CMakeExtension) -> None: # Select the build type. # Note: optimization level + debug info are set by the build type default_cfg = "Debug" if self.debug else "RelWithDebInfo" - cfg = os.getenv("CMAKE_BUILD_TYPE", default_cfg) + cfg = envs.CMAKE_BUILD_TYPE or default_cfg # where .so files will be written, should be the same for all extensions # that use the same CMakeLists.txt. @@ -117,7 +132,7 @@ def configure(self, ext: CMakeExtension) -> None: '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), ] - verbose = bool(int(os.getenv('VERBOSE', '0'))) + verbose = envs.VERBOSE if verbose: cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON'] @@ -172,19 +187,22 @@ def build_extensions(self) -> None: if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) + targets = [] # Build all the extensions for ext in self.extensions: self.configure(ext) + targets.append(remove_prefix(ext.name, "vllm.")) - ext_target_name = remove_prefix(ext.name, "vllm.") - num_jobs, _ = self.compute_num_jobs() + num_jobs, _ = self.compute_num_jobs() - build_args = [ - '--build', '.', '--target', ext_target_name, '-j', - str(num_jobs) - ] + build_args = [ + "--build", + ".", + f"-j={num_jobs}", + *[f"--target={name}" for name in targets], + ] - subprocess.check_call(['cmake', *build_args], cwd=self.build_temp) + subprocess.check_call(["cmake", *build_args], cwd=self.build_temp) def _is_cuda() -> bool: @@ -204,7 +222,7 @@ def _is_neuron() -> bool: subprocess.run(["neuron-ls"], capture_output=True, check=True) except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): torch_neuronx_installed = False - return torch_neuronx_installed + return torch_neuronx_installed or envs.VLLM_BUILD_WITH_NEURON def _is_cpu() -> bool: @@ -212,7 +230,7 @@ def _is_cpu() -> bool: def _install_punica() -> bool: - return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))) + return envs.VLLM_INSTALL_PUNICA_KERNELS def get_hipcc_rocm_version(): @@ -261,6 +279,7 @@ def get_nvcc_cuda_version() -> Version: Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py """ + assert CUDA_HOME is not None, "CUDA_HOME is not set" nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) output = nvcc_output.split() @@ -339,14 +358,15 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda(): requirements = _read_requirements("requirements-cuda.txt") - cuda_major = torch.version.cuda.split(".")[0] + cuda_major, cuda_minor = torch.version.cuda.split(".") modified_requirements = [] for req in requirements: - if "vllm-nccl-cu12" in req: - modified_requirements.append( - req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}")) - else: - modified_requirements.append(req) + if ("vllm-flash-attn" in req + and not (cuda_major == "12" and cuda_minor == "1")): + # vllm-flash-attn is built only for CUDA 12.1. + # Skip for other versions. + continue + modified_requirements.append(req) requirements = modified_requirements elif _is_hip(): requirements = _read_requirements("requirements-rocm.txt") @@ -362,20 +382,23 @@ def _read_requirements(filename: str) -> List[str]: ext_modules = [] -if _is_cuda() and _install_punica(): - ext_modules.append(CMakeExtension(name="vllm._punica_C")) - if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) +if _is_hip(): + ext_modules.append(CMakeExtension(name="vllm._custom_C")) + if not _is_neuron(): ext_modules.append(CMakeExtension(name="vllm._C")) - ext_modules.append(CMakeExtension(name="vllm._custom_C")) + + if _install_punica(): + ext_modules.append(CMakeExtension(name="vllm._punica_C")) package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } -if os.environ.get("VLLM_USE_PRECOMPILED"): +if envs.VLLM_USE_PRECOMPILED: + ext_modules = [] package_data["vllm"].append("*.so") setup( @@ -401,10 +424,13 @@ def _read_requirements(filename: str) -> List[str]: "Topic :: Scientific/Engineering :: Artificial Intelligence", ], packages=find_packages(exclude=("benchmarks", "csrc", "docs", "examples", - "tests")), + "tests*")), python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, + extras_require={ + "tensorizer": ["tensorizer>=2.9.0"], + }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, ) diff --git a/vllm/model_executor/parallel_utils/__init__.py b/tests/async_engine/__init__.py similarity index 100% rename from vllm/model_executor/parallel_utils/__init__.py rename to tests/async_engine/__init__.py diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index 248bfbc8ab5c0..7f57d5cf9b182 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -25,21 +25,30 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture -def api_server(tokenizer_pool_size: int): +def api_server(tokenizer_pool_size: int, engine_use_ray: bool, + worker_use_ray: bool): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() - uvicorn_process = subprocess.Popen([ + commands = [ sys.executable, "-u", str(script_path), "--model", "facebook/opt-125m", "--host", "127.0.0.1", "--tokenizer-pool-size", str(tokenizer_pool_size) - ]) + ] + if engine_use_ray: + commands.append("--engine-use-ray") + if worker_use_ray: + commands.append("--worker-use-ray") + uvicorn_process = subprocess.Popen(commands) yield uvicorn_process.terminate() @pytest.mark.parametrize("tokenizer_pool_size", [0, 2]) -def test_api_server(api_server, tokenizer_pool_size: int): +@pytest.mark.parametrize("worker_use_ray", [False, True]) +@pytest.mark.parametrize("engine_use_ray", [False, True]) +def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool, + engine_use_ray: bool): """ Run the API server and test it. diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index cb125a7bfec30..10a46422887e3 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -25,7 +25,7 @@ async def step_async(self): return [RequestOutput( request_id=self.request_id)] if self.request_id else [] - async def encode_request_async(self, *args, **kwargs): + async def process_model_inputs_async(self, *args, **kwargs): pass def generate(self, request_id): @@ -91,4 +91,6 @@ async def test_new_requests_event(): assert engine.engine.step_calls == old_step_calls + 1 engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True) + assert engine.get_model_config() is not None assert engine.get_tokenizer() is not None + assert engine.get_decoding_config() is not None diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 6972ae1dee4a1..55b730812ea94 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -76,28 +76,36 @@ def test_load_chat_template(): {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 -def test_no_load_chat_template(): +def test_no_load_chat_template_filelike(): # Testing chatml template template = "../../examples/does_not_exist" tokenizer = MockTokenizer() + mock_serving_chat = MockServingChat(tokenizer) + + with pytest.raises(ValueError, match="looks like a file path"): + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) + + +def test_no_load_chat_template_literallike(): + # Testing chatml template + template = "{{ messages }}" + tokenizer = MockTokenizer() + mock_serving_chat = MockServingChat(tokenizer) OpenAIServingChat._load_chat_template(mock_serving_chat, chat_template=template) template_content = tokenizer.chat_template - # Test assertions - assert template_content is not None - # Hard coded value for template_chatml.jinja - assert template_content == """../../examples/does_not_exist""" + assert template_content == template -@pytest.mark.asyncio @pytest.mark.parametrize( "model,template,add_generation_prompt,expected_output", MODEL_TEMPLATE_GENERATON_OUTPUT) -async def test_get_gen_prompt(model, template, add_generation_prompt, - expected_output): +def test_get_gen_prompt(model, template, add_generation_prompt, + expected_output): # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) mock_serving_chat = MockServingChat(tokenizer) diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py new file mode 100644 index 0000000000000..4c362a0512feb --- /dev/null +++ b/tests/async_engine/test_openapi_server_ray.py @@ -0,0 +1,114 @@ +import openai # use the official client for correctness check +import pytest +# using Ray for overall ease of process management, parallel requests, +# and debugging. +import ray + +from ..utils import ServerRunner + +# any model with a chat template should work here +MODEL_NAME = "facebook/opt-125m" + + +@pytest.fixture(scope="module") +def server(): + ray.init() + server_runner = ServerRunner.remote([ + "--model", + MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "2048", + "--enforce-eager", + "--engine-use-ray" + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + +@pytest.fixture(scope="module") +def client(): + client = openai.AsyncOpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + yield client + + +@pytest.mark.asyncio +async def test_check_models(server, client: openai.AsyncOpenAI): + models = await client.models.list() + models = models.data + served_model = models[0] + assert served_model.id == MODEL_NAME + assert all(model.root == MODEL_NAME for model in models) + + +@pytest.mark.asyncio +async def test_single_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create(model=MODEL_NAME, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + + +@pytest.mark.asyncio +async def test_single_chat_session(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + # test single completion + chat_completion = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert chat_completion.id is not None + assert chat_completion.choices is not None and len( + chat_completion.choices) == 1 + assert chat_completion.choices[0].message is not None + assert chat_completion.choices[0].logprobs is not None + assert chat_completion.choices[0].logprobs.content[ + 0].top_logprobs is not None + assert len( + chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5 + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 diff --git a/tests/basic_correctness/__init__.py b/tests/basic_correctness/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 97cff623c5e1d..7d8117447ca0a 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -2,12 +2,28 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ +import os +import weakref + import pytest +from vllm import LLM + MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" + + +def test_vllm_gc_ed(): + """Verify vllm instance is GC'ed when it is deleted""" + llm = LLM("facebook/opt-125m") + weak_llm = weakref.ref(llm) + del llm + # If there's any circular reference to vllm, this fails + # because llm instance is not GC'ed. + assert weak_llm() is None @pytest.mark.parametrize("model", MODELS) @@ -23,11 +39,18 @@ def test_models( max_tokens: int, enforce_eager: bool, ) -> None: + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + if backend_by_env_var == "FLASHINFER" and enforce_eager is False: + pytest.skip("Skipping non-eager test for FlashInferBackend.") + hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) + vllm_model = vllm_runner(model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py new file mode 100644 index 0000000000000..47d582c726c66 --- /dev/null +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -0,0 +1,65 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +It tests chunked prefill. Chunked prefill can be enabled by +enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens, +prefill requests are chunked. + +Run `pytest tests/models/test_chunked_prefill.py`. +""" +import pytest + +MODELS = [ + "facebook/opt-125m", + "meta-llama/Llama-2-7b-hf", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, + tensor_parallel_size: int, +) -> None: + max_num_seqs = min(chunked_prefill_token_size, 256) + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py new file mode 100644 index 0000000000000..29a4c39cd25a1 --- /dev/null +++ b/tests/basic_correctness/test_preemption.py @@ -0,0 +1,265 @@ +"""Compare the short outputs of HF and vLLM when using greedy sampling. + +VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test. + +Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 +pytest tests/basic_correctness/test_preemption.py`. +""" +import pytest +from prometheus_client import REGISTRY + +from vllm import SamplingParams +from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, + ENABLE_ARTIFICIAL_PREEMPT) + +MODELS = [ + "facebook/opt-125m", +] + +assert ENABLE_ARTIFICIAL_PREEMPT is True, ( + "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. " + "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest " + "tests/basic_correctness/test_preemption.py`") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_chunked_prefill_recompute( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + """Ensure that chunked prefill works with preemption.""" + max_num_seqs = min(chunked_prefill_token_size, 256) + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + max_num_seqs=max_num_seqs, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_preemption( + caplog_vllm, + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """By default, recompute preemption is enabled""" + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + disable_log_stats=False, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + total_preemption = ( + vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + assert ("is preempted by PreemptionMode.RECOMPUTE mode because there " + "is not enough KV cache space." in caplog_vllm.text) + # Ensure the count bucket of request-level histogram metrics matches + # the number of requests as a simple sanity check to ensure metrics are + # generated + preemption_metrics = None + for m in REGISTRY.collect(): + if m.name == "vllm:num_preemptions": + preemption_metrics = m + assert preemption_metrics is not None + total_recorded_preemption = 0 + for sample in preemption_metrics.samples: + total_recorded_preemption += sample.value + assert total_preemption == total_recorded_preemption + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("beam_width", [4]) +def test_swap( + caplog_vllm, + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + """Use beam search enables swapping.""" + example_prompts = example_prompts[:1] + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, + max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + swap_space=10, + disable_log_stats=False, + ) + vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, + max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + total_preemption = ( + vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, _ = hf_outputs[i] + vllm_output_ids, _ = vllm_outputs[i] + assert len(hf_output_ids) == len(vllm_output_ids) + for j in range(len(hf_output_ids)): + assert hf_output_ids[j] == vllm_output_ids[j], ( + f"Test{i} output{j}:\nHF: {hf_output_ids}\n" + f"vLLM: {vllm_output_ids}") + + assert ("is preempted by PreemptionMode.SWAP mode because there " + "is not enough KV cache space." in caplog_vllm.text) + # Ensure the count bucket of request-level histogram metrics matches + # the number of requests as a simple sanity check to ensure metrics are + # generated + preemption_metrics = None + for m in REGISTRY.collect(): + if m.name == "vllm:num_preemptions": + preemption_metrics = m + assert preemption_metrics is not None + total_recorded_preemption = 0 + for sample in preemption_metrics.samples: + total_recorded_preemption += sample.value + assert total_preemption == total_recorded_preemption + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("beam_width", [4]) +def test_swap_infeasible( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + """Verify infeasible swap request will be ignored.""" + BLOCK_SIZE = 16 + prefill_blocks = 2 + decode_blocks = max_tokens // BLOCK_SIZE + example_prompts = example_prompts[:1] + + vllm_model = vllm_runner( + model, + dtype=dtype, + swap_space=10, + block_size=BLOCK_SIZE, + # Since beam search have more than 1 sequence, prefill + decode blocks + # are not enough to finish. + num_gpu_blocks_override=prefill_blocks + decode_blocks, + max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, + ) + sampling_params = SamplingParams(n=beam_width, + use_beam_search=True, + temperature=0.0, + max_tokens=max_tokens, + ignore_eos=True) + req_outputs = vllm_model.model.generate( + example_prompts, + sampling_params=sampling_params, + ) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + # Verify the request is ignored and not hang. + assert req_outputs[0].outputs[0].finish_reason == "length" + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_preemption_infeasible( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """Verify infeasible preemption request will be ignored.""" + BLOCK_SIZE = 16 + prefill_blocks = 2 + decode_blocks = max_tokens // BLOCK_SIZE + vllm_model = vllm_runner( + model, + dtype=dtype, + block_size=BLOCK_SIZE, + # Not enough gpu blocks to complete a single sequence. + # preemption should happen, and the sequence should be + # ignored instead of hanging forever. + num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, + max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), + ) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) + req_outputs = vllm_model.model.generate( + example_prompts, + sampling_params=sampling_params, + ) + + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + # Verify the request is ignored and not hang. + for req_output in req_outputs: + outputs = req_output.outputs + assert len(outputs) == 1 + assert outputs[0].finish_reason == "length" diff --git a/tests/conftest.py b/tests/conftest.py index e00f3eb871e37..af04cfbbb9902 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,22 @@ import contextlib import gc import os -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import pytest import torch from PIL import Image -from transformers import (AutoModelForCausalLM, AutoProcessor, - LlavaForConditionalGeneration) +from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer, + LlavaConfig, LlavaForConditionalGeneration) from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig -from vllm.model_executor.parallel_utils.parallel_state import ( - destroy_model_parallel) +from vllm.distributed import destroy_model_parallel +from vllm.inputs import PromptInputs +from vllm.logger import init_logger from vllm.sequence import MultiModalData -from vllm.transformers_utils.tokenizer import get_tokenizer + +logger = init_logger(__name__) _TEST_DIR = os.path.dirname(__file__) _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] @@ -130,9 +132,11 @@ def example_long_prompts() -> List[str]: "float": torch.float, } -_VISION_LANGUAGE_MODELS = { - "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration, -} +AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration) + +_EMBEDDING_MODELS = [ + "intfloat/e5-mistral-7b-instruct", +] class HfRunner: @@ -140,32 +144,44 @@ class HfRunner: def __init__( self, model_name: str, - tokenizer_name: Optional[str] = None, dtype: str = "half", ) -> None: assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + self.model_name = model_name - if model_name not in _VISION_LANGUAGE_MODELS: - self.model = AutoModelForCausalLM.from_pretrained( + + if model_name in _EMBEDDING_MODELS: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + self.model = SentenceTransformer( model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() - self.processor = None + device="cpu", + ).to(dtype=torch_dtype).cuda() else: - self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( + self.model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch_dtype, trust_remote_code=True, ).cuda() + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + try: self.processor = AutoProcessor.from_pretrained( model_name, torch_dtype=torch_dtype, + trust_remote_code=True, ) - if tokenizer_name is None: - tokenizer_name = model_name - self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) + except Exception: + logger.warning( + "Unable to auto-load processor from HuggingFace for " + "model %s. Using tokenizer instead.", model_name) + self.processor = self.tokenizer def generate( self, @@ -177,19 +193,19 @@ def generate( if images: assert len(prompts) == len(images) for i, prompt in enumerate(prompts): - if self.model_name not in _VISION_LANGUAGE_MODELS: - input_ids = self.tokenizer(prompt, - return_tensors="pt").input_ids - inputs = {"input_ids": input_ids.cuda()} - else: - image = images[i] if images else None - inputs = self.processor(text=prompt, - images=image, - return_tensors="pt") - inputs = { - key: value.cuda() if value is not None else None - for key, value in inputs.items() - } + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + inputs = { + key: value.cuda() if value is not None else None + for key, value in inputs.items() + } + output_ids = self.model.generate( **inputs, use_cache=True, @@ -273,6 +289,71 @@ def generate_greedy_logprobs( all_logprobs.append(seq_logprobs) return all_logprobs + def generate_greedy_logprobs_limit( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + ) -> List[Tuple[List[int], str]]: + all_logprobs = [] + all_output_ids = [] + all_output_strs = [] + + for prompt in prompts: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + output = self.model.generate( + input_ids.cuda(), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + ) + + seq_logprobs = [] + for _, hidden_states in enumerate(output.hidden_states): + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if getattr(self.model.get_output_embeddings(), "bias", + None) is not None: + logits += self.model.get_output_embeddings( + ).bias.unsqueeze(0) + logprobs = torch.nn.functional.log_softmax(logits, + dim=-1, + dtype=torch.float32) + seq_logprobs.append(logprobs) + + # convert to dict + seq_logprobs_lst = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = seq_ids.shape[0] - input_ids.shape[1] + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + return self.model.encode(prompts) + def __del__(self): del self.model cleanup() @@ -297,6 +378,7 @@ def __init__( tensor_parallel_size: int = 1, block_size: int = 16, enable_chunked_prefill: bool = False, + swap_space=4, **kwargs, ) -> None: self.model = LLM( @@ -304,7 +386,7 @@ def __init__( tokenizer=tokenizer_name, trust_remote_code=True, dtype=dtype, - swap_space=0, + swap_space=swap_space, disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, max_model_len=max_model_len, @@ -321,12 +403,22 @@ def generate( ) -> List[Tuple[List[int], str]]: if images is not None: assert len(prompts) == images.shape[0] - req_outputs = self.model.generate( - prompts, - sampling_params=sampling_params, - multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE, - data=images) - if images is not None else None) + + prompt_inputs: List[PromptInputs] = [] + for i, prompt in enumerate(prompts): + image = None if images is None else images[i:i + 1] + mm_data = None if image is None else MultiModalData( + type=MultiModalData.Type.IMAGE, + data=image, + ) + + prompt_inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data, + }) + + req_outputs = self.model.generate(prompt_inputs, + sampling_params=sampling_params) outputs = [] for req_output in req_outputs: prompt_str = req_output.prompt @@ -397,12 +489,20 @@ def generate_beam_search( outputs = self.generate(prompts, beam_search_params) return outputs + def encode(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs.embedding + outputs.append(embedding) + return outputs + def __del__(self): del self.model cleanup() -@pytest.fixture +@pytest.fixture(scope="session") def vllm_runner(): return VllmRunner @@ -415,3 +515,19 @@ def get_tokenizer_pool_config(tokenizer_group_type): pool_type="ray", extra_config={}) raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") + + +@pytest.fixture() +def temporary_enable_log_propagate(): + import logging + logger = logging.getLogger("vllm") + logger.propagate = True + yield + logger.propagate = False + + +@pytest.fixture() +def caplog_vllm(temporary_enable_log_propagate, caplog): + # To capture vllm log, we should enable propagate=True temporarily + # because caplog depends on logs propagated to the root logger. + yield caplog diff --git a/tests/core/block/e2e/__init__.py b/tests/core/block/e2e/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index 1d99cb5d32219..e870597b7a011 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -1,9 +1,12 @@ +from typing import Callable, Iterable, Optional + import pytest -from tests.conftest import cleanup from vllm import LLM from vllm.model_executor.utils import set_random_seed +from ....conftest import cleanup + @pytest.fixture def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, @@ -39,3 +42,27 @@ def generator_inner(): for llm in generator_inner(): yield llm del llm + + +def get_text_from_llm_generator(llm_generator: Iterable[LLM], + prompts, + sampling_params, + llm_cb: Optional[Callable[[LLM], + None]] = None): + for llm in llm_generator: + if llm_cb: + llm_cb(llm) + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + text = [output.outputs[0].text for output in outputs] + del llm + + return text + + +def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): + for llm in llm_generator: + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] + del llm + + return token_ids diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 94b65401e1dd4..3713ef2fed4d1 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -4,6 +4,8 @@ from vllm import SamplingParams +from .conftest import get_token_ids_from_llm_generator + @pytest.mark.parametrize( "common_llm_kwargs", @@ -230,10 +232,217 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, assert baseline_token_ids == test_token_ids -def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): - for llm in llm_generator: - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - token_ids = [output.outputs[0].token_ids for output in outputs] - del llm +@pytest.mark.parametrize( + "common_llm_kwargs", + [ + { + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 2, + "max_num_seqs": 2, + }, + ]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [ + { + "use_v2_block_manager": False, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "use_v2_block_manager": True, + "num_lookahead_slots": 0, + }, + { + "use_v2_block_manager": True, + "num_lookahead_slots": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_chunked_prefill_block_manager_v2(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify that chunked prefill works with BlockManagerV2, with and without + lookahead scheduling. + """ + output_len = 32 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids with BlockManagerV1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids with BlockManagerV2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Allow only 5 sequences of ~1024 tokens in worst case. + "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), - return token_ids + # Enable prefill cache + "enable_prefix_caching": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size): + """Verify block manager v2 produces same outputs as block manager v1, even + when there is preemption. + + This constructs two LLM, each with limited number of GPU blocks. The limit + is decided such that as the sequences in the batch grow, sequences must be + preempted and removed from cache. + + If the output token ids are equivalent, then we have confidence that the KV + cache is not corrupted in the v2 block manager. + + NOTE: We want a significant number of generated tokens so that any incorrect + KV mapping has time to build up error. + """ + output_len = 1024 + temperature = 0.0 + + # We want to ensure equality even with preemption. + # We force the total block size to be 1 + cdiv(output_len, block_size) + # so that only one sequence can fit at a time (once the sequences grow). + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids from block manager v1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids from block manager v2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Allow only 5 sequences of ~1024 tokens in worst case. + "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), + + # Test APC in v2 block + "use_v2_block_manager": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "enable_prefix_caching": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_auto_prefix_caching_with_preemption(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify block manager v2 with auto prefix caching enabled produces same + outputs as auto prefix caching disabled, even when there is preemption. + + This constructs two LLM, each with limited number of GPU blocks. The limit + is decided such that as the sequences in the batch grow, sequences must be + preempted and removed from cache. + + If the output token ids are equivalent, then we have confidence that auto + prefix caching itself at least don't cause result error. + """ + output_len = 1024 + temperature = 0.0 + + # We want to ensure equality even with preemption. + # We force the total block size to be 1 + cdiv(output_len, block_size) + # so that only one sequence can fit at a time (once the sequences grow). + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids with APC disabled') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids with APC enabled') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py new file mode 100644 index 0000000000000..e98292e807d73 --- /dev/null +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -0,0 +1,168 @@ +import random +from typing import List + +import pytest + +from vllm import LLM, SamplingParams + +from .conftest import get_text_from_llm_generator + +# relatively small model with 4k sliding window +MODEL = "bigcode/starcoder2-3b" +BLOCK_SIZE = 16 + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": MODEL, + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "block_size": BLOCK_SIZE, + # needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008 + "num_gpu_blocks_override": 100000 // BLOCK_SIZE, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, + batch_size, seed): + """ + The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then + asks for value of one of them (which is outside the sliding window). + If we tell it upfront which we are going to be looking for, then + it answers correctly (mostly). + + Additionally, we compare the results of the v1 and v2 managers. + """ + sampling_params = SamplingParams( + max_tokens=1024, + ignore_eos=True, + temperature=0.0, + ) + + prompts, answer, indices = prep_prompts(batch_size) + + print('Getting token ids from block manager v1') + baseline_texts = get_text_from_llm_generator(baseline_llm_generator, + prompts, + sampling_params, + llm_cb=check_window(prompts)) + + check_answers(indices, answer, baseline_texts) + + print('Getting token ids from block manager v2') + test_texts = get_text_from_llm_generator(test_llm_generator, prompts, + sampling_params) + check_answers(indices, answer, test_texts) + + cmp = [ + expected_text == actual_text + for expected_text, actual_text in zip(baseline_texts, test_texts) + ] + print(cmp) + # make sure it's mostly OK; this is possibly because https://github.com/vllm-project/vllm/pull/4768 + # however, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290 + # states that xformers and flash_attn have different ideas about the window + # size anyways + assert sum(cmp) > 0.7 * len(cmp) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": MODEL, + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "block_size": BLOCK_SIZE, + "num_gpu_blocks_override": 100000 // BLOCK_SIZE, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "use_v2_block_manager": True, + "enable_chunked_prefill": True +}]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): + """ + This is similar to test_sliding_window_retrival, however, it doesn't + compare against the v1 block manager since v1 doesn't support + chunked prefill with sliding window. + + The results with and without chunked prefill are not the same due to + numerical instabilities. + """ + sampling_params = SamplingParams( + max_tokens=10, + ignore_eos=True, + temperature=0.0, + ) + + prompts, answer, indices = prep_prompts(batch_size) + + # We don't compare with the baseline model here, since the results + # slightly different due to different tailing in attention. + test_texts = get_text_from_llm_generator(test_llm_generator, + prompts, + sampling_params, + llm_cb=check_window(prompts)) + check_answers(indices, answer, test_texts) + + +def prep_prompts(batch_size: int): + """ + Generate prompts which a bunch of assignments, + then asking for the value of one of them. + The prompt is just under 10k tokens; sliding window is 4k + so the answer is outside sliding window, but should still be correct. + """ + prompts: List[str] = [] + answer: List[int] = [] + indices: List[int] = [] + random.seed(1) + for _ in range(batch_size): + idx = random.randint(30, 90) + indices.append(idx) + prompt = "```python\n# We set a number of variables, " + \ + f"x{idx} will be important later\n" + ln = random.randint(800, 1100) + for k in range(30, ln): + v = random.randint(10, 99) + if k == idx: + answer.append(v) + prompt += f"x{k} = {v}\n" + prompt += f"# Now, we check the value of x{idx}:\n" + prompt += f"assert x{idx} == " + prompts.append(prompt) + return prompts, answer, indices + + +def check_answers(indices: List[int], answer: List[int], outputs: List[str]): + answer2 = [int(text[0:2].strip()) for text in outputs] + print(list(zip(indices, zip(answer, answer2)))) + numok = 0 + for a1, a2 in zip(answer, answer2): + if a1 == a2: + numok += 1 + frac_ok = numok / len(answer) + print(f"Num OK: {numok}/{len(answer)} {frac_ok}") + assert frac_ok > 0.7 + + +def check_window(prompts: List[str]): + + def inner(llm: LLM): + sliding_window = llm.llm_engine.model_config.get_sliding_window() + assert sliding_window and sliding_window > 0 + assert any( + len(llm.get_tokenizer().tokenize(prompt)) > sliding_window + for prompt in prompts) + + return inner diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 1e8e4ccdfb151..f98fc0e217278 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -1,11 +1,13 @@ import pytest +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.block_manager_v2 import BlockSpaceManagerV2 from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list -from ..utils import create_seq_group +from ..utils import create_seq_group, create_seq_group_encoder_decoder @pytest.mark.parametrize("block_size", [16]) @@ -52,6 +54,156 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, assert can_allocate_result == AllocStatus.LATER +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160]) +@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_seq_group_encoder_decoder(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + ) + num_watermark_blocks = int(watermark * num_gpu_blocks) + + num_output_blocks_per_seq = 1 + + # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but + # the current implementation assumes all seqs are new prompts / don't have + # different output lens. + num_output_blocks = num_output_blocks_per_seq + + for bdx, num_prompt_blocks in enumerate( + range(1, num_gpu_blocks - num_output_blocks)): + num_cross_blocks_per_seq = num_prompt_blocks + + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id=str(bdx)) + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + can_allocate_result = block_manager.can_allocate(seq_group) + + num_required_blocks = num_prompt_blocks + \ + num_output_blocks + \ + num_cross_blocks_per_seq + + if num_gpu_blocks - num_required_blocks < num_watermark_blocks: + assert can_allocate_result == AllocStatus.NEVER + elif num_gpu_blocks >= num_required_blocks: + assert can_allocate_result == AllocStatus.OK + else: + assert can_allocate_result == AllocStatus.LATER + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16]) +@pytest.mark.parametrize("num_seqs_per_group", [1]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): + ''' + SWA short for Sliding Window Attention. + + At time of writing block manager v2 does not support SWA. + + However even when SWA is implemented for block manager v2, + there will still most likely be a separate workstream required + to enable SWA for encoder/decoder models. + + Therefore this test enforces that one of the following cases + hold true: + 1. Block manager v2 does not support SWA at all (true at time of writing) + 2. Block manager v2 fails with NotImplementError when SWA is enabled + AND a SequenceGroup with an encoder sequence (i.e. in support of an + encoder/decoder model) is passed into can_allocate() as an argument + + The setup for this test is stripped down version of + test_can_allocate_seq_group_encoder_decoder() + ''' + + with pytest.raises((NotImplementedError, AssertionError)) as exc_info: + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + sliding_window=5 # SWA + ) + + num_output_blocks_per_seq = 1 + num_prompt_blocks = 1 + num_output_blocks = num_output_blocks_per_seq + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id="0") + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + block_manager.can_allocate(seq_group) + + # Assert that either + # 1. Block manager v2 constructor fails with assertion that sliding window + # is not yet supported (most likely near-term outcome at time of + # writing), or + # 2. can_allocate() fails with NotImplementedError due to combination of + # encoder/decoder and sliding window attention + if isinstance(exc_info.value, NotImplementedError): + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA + elif isinstance(exc_info.value, AssertionError): + assert str(exc_info.value) == "Sliding window not yet supported" + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16]) +@pytest.mark.parametrize("num_seqs_per_group", [1]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_encoder_decoder_fails_with_prefix_cache( + block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, + watermark: float): + + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + enable_caching=True # Prefix cache + ) + + num_output_blocks_per_seq = 1 + num_prompt_blocks = 1 + num_output_blocks = num_output_blocks_per_seq + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id="0") + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + # Assert that either can_allocate() fails with NotImplementedError + # due to combination of encoder/decoder and prefix cache + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE + + @pytest.mark.parametrize("block_size", [1, 8]) @pytest.mark.parametrize("prompt_len", [1, 7, 8]) @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) @@ -101,3 +253,72 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, range(prompt_len + num_slots_to_append + num_lookahead_slots)), block_size)) - len(chunk_list(list(range(prompt_len)), block_size)) assert num_consumed_blocks == expected_consumed_blocks + + +@pytest.mark.parametrize("block_size", [8, 16]) +@pytest.mark.parametrize("prompt_len", [10, 300, 1000]) +@pytest.mark.parametrize("num_slots_to_append", [50]) +@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512]) +def test_sliding_window(block_size, prompt_len, num_slots_to_append, + sliding_window): + """Verify append_slots consumes the correct number of blocks from the block + table. + """ + + num_gpu_blocks = 1024 + watermark = 0.1 + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + watermark=watermark, + sliding_window=sliding_window, + ) + + def check_used(min_n, max_n=None): + if max_n is None: + max_n = min_n + used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks() + #print("check", min_n, used, max_n) + assert min_n <= used + assert used <= max_n + + def num_blocks(num_tokens): + return (num_tokens + block_size - 1) // block_size + + check_used(0) + + seq_group = create_seq_group( + seq_prompt_len=prompt_len, + seq_output_lens=[0], + ) + + check_used(0) + + # Allocate seq + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + + check_used(num_blocks(prompt_len)) + + # Seq seq to RUNNING + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + seq.data.update_num_computed_tokens(prompt_len) + check_used(num_blocks(prompt_len)) + + # this is how we compute it in BlockSpaceManagerV2.__init__ + sliding_blocks = (sliding_window // block_size) + 2 + # plus one block for null block + sliding_blocks += 1 + + # Append tokens to the sequeqnce + for token_id in range(num_slots_to_append): + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + seq.data.update_num_computed_tokens(1) + block_manager.append_slots(seq, num_lookahead_slots=0) + if prompt_len < sliding_window + 10: + check_used(0, sliding_blocks + 1) + else: + check_used(sliding_blocks, sliding_blocks + 1) diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index 3481d6b4312c1..6fb95cfdfab81 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -410,8 +410,7 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, expected_src = static_block_table.physical_block_ids[cow_block_id] expected_dst = appender_block_table.physical_block_ids[cow_block_id] - assert expected_src in cows - assert expected_dst in cows[expected_src] + assert (expected_src, expected_dst) in cows else: # Otherwise, there should be no copy-on-write. assert not cows @@ -490,8 +489,7 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, expected_src = static_block_table.physical_block_ids[cow_block_id] expected_dst = appender_block_table.physical_block_ids[cow_block_id] - assert expected_src in cows - assert expected_dst in cows[expected_src] + assert (expected_src, expected_dst) in cows static_block_table.free() appender_block_table.free() diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 5f4d58dd5fd39..bcf08cda09f46 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -358,6 +358,248 @@ def test_get_num_free_blocks_shared(num_blocks: int, block_size: int, i) allocator.free(block) + @staticmethod + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(20))) + def test_get_common_computed_block_ids(num_blocks: int, block_size: int, + seed: int): + """Verify get_common_computed_block_ids could get correct result + by create two immutable chain sharing prefix at specified pos, + and compare whether we also could get right result + from get_common_computed_block_ids. + """ + random.seed(seed) + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2, + block_size=block_size) + num_blocks_to_consume = random.randint(1, num_blocks - 1) + + # Create token ids that will exhaust all blocks. + token_ids = list(range(num_blocks_to_consume * block_size)) + blocks = list(range(num_blocks_to_consume)) + + first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + # mark all blocks in first chain as computed + allocator.mark_blocks_as_computed(blocks) + + # After zero_point, second_chain's token_ids would be set -1, which + # make it different from here comparing with first_chain + zero_point = random.randint(1, len(token_ids) - 1) + zero_point_blocks = zero_point // block_size + token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point) + + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + first_computed_ids = [ + first_chain[i].block_id for i in range(num_blocks_to_consume) + ] + second_computed_ids = [ + second_chain[i].block_id for i in range(num_blocks_to_consume) + ] + res = allocator.get_common_computed_block_ids( + [first_computed_ids, second_computed_ids]) + + assert (len(res) == zero_point_blocks) + + # Test case that assume those prompted block after first immutable would + # be freed into hashless allocator, while first immutable block get ref + # increased. + @staticmethod + @pytest.mark.parametrize("num_blocks", [3]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(10))) + def test_alloc_promotion(num_blocks: int, block_size: int, seed: int): + random.seed(seed) + + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + token_ids = list(range(block_size)) + + block = allocator.allocate_immutable(prev_block=None, + token_ids=token_ids) + + assert allocator._refcounter.get(block.block_id) == 1 + m = allocator.allocate_mutable(prev_block=None) + + block_id = m.block_id + for i in range(block_size): + m.append_token_ids([i]) + # After block get promoted to immutable from mutable, if there is + # already same content hash block, then it shall be released into + # hashless_allocator + # And first immutable block's ref get increased by 1 + assert m.block_id == block.block_id + assert block_id in allocator._hashless_allocator._free_block_indices + assert allocator._refcounter.get(block.block_id) == 2 + + # Test case when eviction and allocation are mixed, + # make sure they work as expected + @staticmethod + @pytest.mark.parametrize("num_blocks", [3]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(10))) + def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): + random.seed(seed) + + all_blocks_list = [i for i in range(num_blocks)] + zero_ref = {i: 0 for i in range(num_blocks)} + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + token_ids = list(range(num_blocks * block_size)) + + # now we have num_blocks free blocks in hashless allocator + # with internal tracking list _blocks _cached_blocks and evictor + # empty and block's ref shall be 0 + assert list(allocator._hashless_allocator._free_block_indices + ) == all_blocks_list + assert len(allocator._blocks.keys()) == 0 + assert len(allocator._cached_blocks.values()) == 0 + assert len(allocator.evictor.free_table.keys()) == 0 + assert allocator._refcounter._refcounts == zero_ref + + # Allocate immutable chains with only one block residuled in + new_block = [] + for i in range(num_blocks): + block = allocator.allocate_immutable( + prev_block=None, + token_ids=token_ids[block_size * i:block_size * (i + 1)]) + new_block.append(block) + + # Free all blocks, and now all blocks shall be in the evictor + # there shall be no tracking data left in _blocks + # all blocks shall be tracked in _cached_blocks + # all blocks' ref shall be zero + for block in new_block: + allocator.free(block) + + assert len(allocator._blocks.keys()) == 0 + assert len(allocator._hashless_allocator._free_block_indices) == 0 + assert list(allocator._cached_blocks.values()) == all_blocks_list + assert list(allocator.evictor.free_table.keys()) == all_blocks_list + assert allocator._refcounter._refcounts == zero_ref + + # Allocate a mutable block, and the first block shall be evicted + # and set its content hash into None, ref to 1 + mutable = allocator.allocate_mutable(prev_block=None) + + assert mutable.block_id == 0 + assert mutable.content_hash is None + assert 0 in allocator._blocks + assert allocator._refcounter.get(0) == 1 + assert 0 not in allocator._cached_blocks + assert 0 not in allocator.evictor + + # Since this mutable block has no hash yet, it shall be released into + # hashless allocator + allocator.free(mutable) + + assert len(allocator._blocks.keys()) == 0 + assert allocator._refcounter._refcounts == zero_ref + assert 0 not in allocator._cached_blocks + assert 0 not in allocator.evictor + assert 0 in allocator._hashless_allocator._free_block_indices + + # when allocate immutable with first block_size tokens, we + # shall get free block from hashless allocator, thus no block left + # in hashless + block = allocator.allocate_immutable(prev_block=None, + token_ids=token_ids[:block_size]) + + assert block.block_id == 0 + assert len(allocator._hashless_allocator._free_block_indices) == 0 + assert 0 in allocator._blocks + assert 0 in allocator._cached_blocks.values() + assert allocator._refcounter.get(0) == 1 + assert 0 not in allocator.evictor + + # allocate mutable block again, it shall be popped from evictor + mutable = allocator.allocate_mutable(prev_block=None) + assert len(allocator._hashless_allocator._free_block_indices) == 0 + assert mutable.block_id not in allocator.evictor.free_table + assert allocator._refcounter.get(mutable.block_id) == 1 + + # Test case where two last accessed times are equal + @staticmethod + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(20))) + def test_eviction_order(num_blocks: int, block_size: int, seed: int): + """This test case simulate the two chain created and free in order, + and together they would exhaust the initial freed blocks. + + So the next block created after those two chain shall use the block + from the first chain as that block has long access time. + While first chain has two blocks, it shall pick up the last one, as + it has larger token number. + """ + + random.seed(seed) + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + num_blocks_to_consume = num_blocks + 1 + + token_ids = list(range(num_blocks_to_consume * block_size)) + + num_blocks_in_first_chain = 2 + num_tokens_in_first_chain = block_size * num_blocks_in_first_chain + # First chain takes the first block + first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[:num_tokens_in_first_chain], + allocator=allocator, + ) + # There should only be one block allocated at this point + assert allocator.get_num_free_blocks() == (num_blocks - + num_blocks_in_first_chain) + + # Set the last accessed time of the first block to 1 + blocks_ids = [block.block_id for block in first_chain] + allocator.mark_blocks_as_accessed(blocks_ids, 1) + + # Second chain takes the rest of the blocks + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[num_tokens_in_first_chain:-block_size], + allocator=allocator, + ) + + # There shouldn't be any blocks left at this point + assert allocator.get_num_free_blocks() == (0) + + assert len(first_chain) == num_blocks_in_first_chain + last_block_id = first_chain[-1].block_id + # Free each block in the first chain. + for i, block in enumerate(first_chain): + allocator.free(block) + + # Set the last accessed time on all of the blocks in the second chain + # to 2 + blocks_ids = [block.block_id for block in second_chain] + allocator.mark_blocks_as_accessed(blocks_ids, 2) + + # Free each block in the second chain. + for i, block in enumerate(second_chain): + allocator.free(block) + + # Allocate a new block and check that it's the least recently used block + # from the first chain. + new_block = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[-block_size:], + allocator=allocator, + ) + + assert new_block[0].block_id == last_block_id + @staticmethod def create_immutable_chain( block_size: int, diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 62984ef4caabb..cd306b9e4d3cc 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -1,17 +1,20 @@ import time +from collections import defaultdict from typing import List import pytest from vllm import SamplingParams from vllm.block import PhysicalTokenBlock +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, UncachedBlockAllocator) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -from .utils import create_dummy_prompt +from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder def test_block_allocator_allocate(): @@ -72,7 +75,7 @@ def test_allocate(): # Allocate same sequence group to all available gpu blocks. for i in range(num_gpu_blocks): _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -84,11 +87,107 @@ def test_allocate(): watermark=1 / num_gpu_blocks) for i in range(num_gpu_blocks - 1): _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK +def test_allocate_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_req_per_seq_group = 2 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same sequence group to all available gpu blocks. + for i in range(num_gpu_blocks // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + str(i), + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + # Allocate same sequence group to all available gpu blocks. + # Use watermark to reserve one gpu block. + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=1 / num_gpu_blocks) + for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + str(i), + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + +def test_allocate_encoder_decoder_fails_with_swa(): + # SWA short for sliding window attention + + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + sliding_window=5) # swa + + # Allocate same sequence group to all available gpu blocks. + _, _, seq_group = create_dummy_prompt_encoder_decoder( + "0", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + + # Assert that can_allocate() fails due to SWA + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA + + # Assert that allocate() fails due to SWA + with pytest.raises(NotImplementedError) as exc_info: + block_manager.allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA + + +def test_allocate_encoder_decoder_fails_with_prefix_caching(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=True) # Prefix cache + + # Allocate same sequence group to all available gpu blocks. + _, _, seq_group = create_dummy_prompt_encoder_decoder( + "0", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + + # Assert that can_allocate() fails due to prefix caching + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE + + # Assert that allocate() fails due to prefix caching + with pytest.raises(NotImplementedError) as exc_info: + block_manager.allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE + + def test_append_slot_single_seq(): block_size = 4 num_cpu_blocks = 4 @@ -132,8 +231,10 @@ def test_append_slot_cow(): # Allocate prompt to gpu block. There is one slot left in the block. prompt = Sequence(seq_id=1, - prompt="one two three", - prompt_token_ids=[1, 2, 3], + inputs={ + "prompt": "one two three", + "prompt_token_ids": [1, 2, 3], + }, block_size=block_size) # Fork the sequence, such that a COW will be required when we append a new @@ -141,8 +242,10 @@ def test_append_slot_cow(): child = prompt.fork(new_seq_id=2) # Allocate space for the sequence group. - seq_group = SequenceGroup("1", [prompt, child], SamplingParams(), - time.time(), time.perf_counter) + seq_group = SequenceGroup(request_id="1", + seqs=[prompt, child], + arrival_time=time.time(), + sampling_params=SamplingParams()) block_manager.allocate(seq_group) # Fork and append a new token id. We expect a COW to be scheduled. @@ -155,7 +258,10 @@ def test_append_slot_cow(): cows = block_manager.append_slots(child) assert cows - for src_block, dst_blocks in cows.items(): + dict_cows = defaultdict(list) + for src_block, dst_block in cows: + dict_cows[src_block].append(dst_block) + for src_block, dst_blocks in dict_cows.items(): assert src_block not in dst_blocks after_blocks = block_manager.get_num_free_gpu_blocks() @@ -215,7 +321,7 @@ def test_swap(): before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_out(seq_group) - assert list(mapping.keys()) == gpu_blocks + assert [x[0] for x in mapping] == gpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) @@ -224,11 +330,67 @@ def test_swap(): # Swap seq group from CPU -> GPU. cpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_in(seq_group) + assert block_manager.can_swap_in(seq_group) == AllocStatus.OK before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_in(seq_group) - assert list(mapping.keys()) == cpu_blocks + assert [x[0] for x in mapping] == cpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks + assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) + + +def test_swap_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + decoder_prompt, encoder_prompt, seq_group = \ + create_dummy_prompt_encoder_decoder( + "1", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + decoder_prompt.status = SequenceStatus.WAITING + encoder_prompt.status = SequenceStatus.WAITING + block_manager.allocate(seq_group) + + # Emulate a forward pass by appending a single token. + # The block manager then knows how many unprocessed + # tokens will be written in the next forward pass. + token_id = 0 + decoder_prompt.status = SequenceStatus.RUNNING + decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) + + # Swap encoder/decoder seq group from GPU -> CPU. + decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt) + cross_gpu_blocks = block_manager.get_cross_block_table(seq_group) + gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks + assert block_manager.can_swap_out(seq_group) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_out(seq_group) + assert [x[0] for x in mapping] == gpu_blocks + #assert list(mapping.keys()) == gpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) + assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks + decoder_prompt.status = SequenceStatus.SWAPPED + + # Swap encoder/decoder seq group from CPU -> GPU. + decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) + cross_cpu_blocks = block_manager.get_cross_block_table(seq_group) + cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks + assert block_manager.can_swap_in(seq_group) == AllocStatus.OK + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_in(seq_group) + assert [x[0] for x in mapping] == cpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks @@ -259,6 +421,41 @@ def test_free(): block_manager.get_block_table(prompt) +def test_free_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + decoder_prompt, encoder_prompt, seq_group = \ + create_dummy_prompt_encoder_decoder( + "1", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + block_manager.allocate(seq_group) + + # Free allocated seq. + decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt)) + encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group)) + prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks + before_blocks = block_manager.get_num_free_gpu_blocks() + block_manager.free(decoder_prompt) + block_manager.free_cross(seq_group) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert after_blocks == before_blocks + prompt_blocks + + # Block table for freed encoder & decoder seq's are deleted. + with pytest.raises(KeyError): + block_manager.get_block_table(decoder_prompt) + + # Block table for freed encoder & decoder seq's are deleted. + with pytest.raises(KeyError): + block_manager.get_block_table(encoder_prompt) + + def test_reset(): block_size = 4 num_cpu_blocks = 4 @@ -280,6 +477,31 @@ def test_reset(): assert block_manager.get_num_free_gpu_blocks() == original_blocks +def test_reset_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_req_per_seq_group = 2 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same seq group on all available gpu blocks. + original_blocks = block_manager.get_num_free_gpu_blocks() + for i in range(num_gpu_blocks // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + f"{i}", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + block_manager.allocate(seq_group) + assert block_manager.get_num_free_gpu_blocks() == 0 + + # Resetting block manager frees all allocated blocks. + block_manager.reset() + assert block_manager.get_num_free_gpu_blocks() == original_blocks + + def test_sliding_window_multi_seq(): """ Tests that memory allocation and deallocation is handled @@ -298,9 +520,17 @@ def test_sliding_window_multi_seq(): assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - parent = Sequence(1, "one two three", [0, 1, 2], block_size) - seq_group = SequenceGroup("1", [parent], SamplingParams(), time.time(), - None) + parent = Sequence(seq_id=1, + inputs={ + "prompt": "one two three", + "prompt_token_ids": [0, 1, 2], + }, + block_size=block_size) + seq_group = SequenceGroup(request_id="1", + seqs=[parent], + arrival_time=time.time(), + sampling_params=SamplingParams(), + lora_request=None) block_manager.allocate(seq_group) # assert the number of blocks allocated is correct diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 05e62ced5898f..3649e6b003a5d 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -4,6 +4,7 @@ import pytest # noqa from vllm.config import CacheConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler from vllm.sequence import Logprob, SequenceGroup @@ -104,10 +105,10 @@ def test_chunk(): # One chunked prefill, and one decoding. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 + # The first one is prefill. Scheduler guarantees ordering. + assert seq_group_meta[0].token_chunk_size == 56 # The second one is a chunked prefill. - assert seq_group_meta[1].token_chunk_size == 56 + assert seq_group_meta[1].token_chunk_size == 1 assert out.num_prefill_groups == 1 assert out.num_batched_tokens == 57 @@ -157,12 +158,12 @@ def test_complex(): # Decoding & chunked prefill & first chunk of 3rd request is scheduled. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert len(get_sequence_groups(out)) == 3 - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 - # The second one is a chunked prefill. + # The first one is the first chunked prefill. + assert seq_group_meta[0].token_chunk_size == 7 + # The second one is the second new chunked prefill. assert seq_group_meta[1].token_chunk_size == 56 - # The third one is also chunked. - assert seq_group_meta[2].token_chunk_size == 7 + # The last one is decode. + assert seq_group_meta[2].token_chunk_size == 1 # Two of them are in chunked prefill. assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 64 @@ -354,8 +355,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 0 assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out != {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out != [] + assert out.blocks_to_swap_in == [] # Add 1 more task. Swap should be prioritized over new prefill. _, seq_group = create_dummy_prompt("2", prompt_length=60) @@ -364,8 +365,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in != {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in != [] + assert out.blocks_to_swap_out == [] def test_running_prefill_prioritized_over_swap(): @@ -405,12 +406,12 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 0 assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out != {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out != [] + assert out.blocks_to_swap_in == [] # Add 1 more task. Swap is not possible, so prefill is running. scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = False + scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER _, seq_group2 = create_dummy_prompt("2", prompt_length=60) scheduler.add_seq_group(seq_group2) @@ -418,18 +419,18 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in == {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in == [] + assert out.blocks_to_swap_out == [] assert out.scheduled_seq_groups[0].seq_group == seq_group2 # Now although swap is possible, running prefill is prioritized. - scheduler.block_manager.can_swap_in.return_value = True + scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in == {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in == [] + assert out.blocks_to_swap_out == [] assert not seq_group2.is_prefill() assert out.scheduled_seq_groups[0].seq_group == seq_group2 append_new_token(seq_group2, 1) @@ -439,8 +440,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 1 - assert out.blocks_to_swap_in == {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in == [] + assert out.blocks_to_swap_out == [] assert not seq_group2.is_prefill() assert out.scheduled_seq_groups[0].seq_group == seq_group2 append_new_token(seq_group2, 1) @@ -450,8 +451,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in != {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in != [] + assert out.blocks_to_swap_out == [] def test_chunked_prefill_preempt(): @@ -492,8 +493,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 0 assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out == {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out == [] + assert out.blocks_to_swap_in == [] # Make sure we can reschedule preempted request. _, out = schedule_and_update_computed_tokens(scheduler) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 9588a1bead5f6..07fc8731e1847 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -180,6 +180,7 @@ def test_scheduler_schedule_preempt_abort(): and not out.blocks_to_swap_out) assert len(seq_group_meta) == 1 assert scheduler.get_num_unfinished_seq_groups() == 2 + assert out.preempted == 1 # Abort seq group a. Re-schedule seq group b prompt with recomputation. scheduler.abort_seq_group("1") @@ -293,8 +294,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 2 assert out.num_batched_tokens == 2 - assert out.blocks_to_swap_out != {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out != [] + assert out.blocks_to_swap_in == [] append_new_token(out, 1) # Add 1 more task. Swap should be prioritized over prefill. @@ -305,8 +306,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 3 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 3 - assert out.blocks_to_swap_in != {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in != [] + assert out.blocks_to_swap_out == [] def initialize_scheduler(*, @@ -540,7 +541,7 @@ def test_decode_schedule_preempted(): curr_loras = None for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) running.append(seq_group) scheduler.block_manager.can_append_slots = MagicMock() @@ -563,11 +564,12 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(output.preempted) == 2 # Verify budgets are updated. assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 1 + # NOTE: When enable_chunk is False, num_seqs budget is not updated. + # assert budget.num_curr_seqs == 1 # Both should be preempted, not swapped. - assert output.blocks_to_swap_out == {} + assert output.blocks_to_swap_out == [] # Nothing is copied. - assert output.blocks_to_copy == {} + assert output.blocks_to_copy == [] def test_decode_swap_beam_search(): @@ -581,7 +583,7 @@ def test_decode_swap_beam_search(): budget = create_token_budget() for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) running.append(seq_group) append_new_token_seq_group(60, seq_group, 1) budget.add_num_seqs(seq_group.request_id, @@ -598,7 +600,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): scheduler.block_manager.can_append_slots.side_effect = ( cannot_append_second_group) scheduler.block_manager.swap_out = MagicMock() - expected_swap_mapping = {"5": "7"} + expected_swap_mapping = [("5", "7")] scheduler.block_manager.swap_out.return_value = expected_swap_mapping remainig_running, output = scheduler._schedule_running( @@ -617,7 +619,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # Both should be preempted, not swapped. assert output.blocks_to_swap_out == expected_swap_mapping # Nothing is copied. - assert output.blocks_to_copy == {} + assert output.blocks_to_copy == [] def test_schedule_decode_blocks_to_copy_update(): @@ -629,13 +631,13 @@ def test_schedule_decode_blocks_to_copy_update(): running = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) running.append(seq_group) # The last request should be swapped out. scheduler.block_manager.append_slots = MagicMock() - scheduler.block_manager.append_slots.return_value = {2: [3]} + scheduler.block_manager.append_slots.return_value = [(2, 3)] budget = create_token_budget() remaining_running, output = scheduler._schedule_running( @@ -646,10 +648,10 @@ def test_schedule_decode_blocks_to_copy_update(): assert len(output.preempted) == 0 assert len(output.swapped_out) == 0 # Nothing is preempted. - assert output.blocks_to_swap_out == {} + assert output.blocks_to_swap_out == [] # Since append_slot returns the source -> dist mapping, it should # applied. - assert output.blocks_to_copy == {2: [3]} + assert output.blocks_to_copy == [(2, 3)] def test_schedule_swapped_simple(): @@ -657,9 +659,9 @@ def test_schedule_swapped_simple(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) @@ -673,9 +675,9 @@ def test_schedule_swapped_simple(): assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 # swap in is the reverse of swap out - blocks_to_swap_in_reverse = {} - for swapin, swapout in output.blocks_to_swap_in.items(): - blocks_to_swap_in_reverse[swapout] = swapin + blocks_to_swap_in_reverse = [] + for swapin, swapout in output.blocks_to_swap_in: + blocks_to_swap_in_reverse.append((swapout, swapin)) assert blocks_to_swap_out == blocks_to_swap_in_reverse @@ -684,10 +686,10 @@ def test_schedule_swapped_max_token_budget(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for _ in range(2): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) @@ -718,10 +720,10 @@ def test_schedule_swapped_max_seqs(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for i in range(4): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) @@ -751,7 +753,7 @@ def test_schedule_swapped_max_loras(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = set() - blocks_to_swap_out = {} + blocks_to_swap_out = [] for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, @@ -759,7 +761,7 @@ def test_schedule_swapped_max_loras(): lora_name=str(i), lora_int_id=i + 1, lora_local_path="abc")) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) @@ -780,17 +782,17 @@ def test_schedule_swapped_cannot_swap_in(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for _ in range(2): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) # The last request should be swapped out. scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = False + scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER # Since we cannot swap in, none of the requests are swapped in. budget = create_token_budget() remaining_swapped, output = scheduler._schedule_swapped( @@ -802,21 +804,49 @@ def test_schedule_swapped_cannot_swap_in(): assert len(output.prefill_seq_groups) == 0 +def test_infeasible_swap(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = [] + for _ in range(2): + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group) + append_new_token_seq_group(60, seq_group, 1) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + # The last request should be swapped out. + scheduler.block_manager.can_swap_in = MagicMock() + scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER + # Since we cannot swap in, none of the requests are swapped in. + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 0 + assert len(output.infeasible_seq_groups) == 2 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(output.decode_seq_groups) == 0 + assert len(output.prefill_seq_groups) == 0 + + def test_schedule_swapped_blocks_to_copy(): scheduler = initialize_scheduler() swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) - blocks_to_swap_out = {} + blocks_to_swap_out = [] scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) # The last request should be swapped out. scheduler.block_manager.append_slots = MagicMock() - scheduler.block_manager.append_slots.return_value = {2: [3]} + scheduler.block_manager.append_slots.return_value = [(2, 3)] budget = create_token_budget() remaining_swapped, output = scheduler._schedule_swapped( @@ -824,7 +854,7 @@ def test_schedule_swapped_blocks_to_copy(): assert len(remaining_swapped) == 0 assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 - assert output.blocks_to_copy == {2: [3]} + assert output.blocks_to_copy == [(2, 3)] def test_scheduling_budget(): diff --git a/tests/core/utils.py b/tests/core/utils.py index fbbdb07cb8e6e..2fbf099c5f90b 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,5 +1,5 @@ import time -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple from vllm import SamplingParams from vllm.lora.request import LoRARequest @@ -21,32 +21,88 @@ def create_dummy_prompt( # and prompt "0 ... block_size". prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup( - request_id, [prompt], - SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - time.time(), lora_request) + prompt = Sequence(int(request_id), + inputs={ + "prompt": prompt_str, + "prompt_token_ids": prompt_tokens, + }, + block_size=block_size) + seq_group = SequenceGroup(request_id=request_id, + seqs=[prompt], + arrival_time=time.time(), + sampling_params=SamplingParams( + use_beam_search=use_beam_search, + best_of=best_of), + lora_request=lora_request) return prompt, seq_group +def create_dummy_prompt_encoder_decoder( + request_id: str, + decoder_prompt_length: int, + encoder_prompt_length: int, + block_size: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + use_beam_search: bool = False, + best_of: int = 1, +) -> Tuple[Sequence, SequenceGroup]: + if not block_size: + block_size = decoder_prompt_length + + # Create dummy prompt sequence with tokens 0...block_size-1 + # and prompt "0 ... block_size". + decoder_prompt_tokens = list(range(decoder_prompt_length)) + decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) + + decoder_prompt = Sequence(int(request_id), + inputs={ + "prompt": decoder_prompt_str, + "prompt_token_ids": decoder_prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) + + encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) + encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) + encoder_prompt = Sequence(int(request_id), + inputs={ + "prompt": encoder_prompt_str, + "prompt_token_ids": encoder_prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) + seq_group = SequenceGroup(request_id=request_id, + seqs=[decoder_prompt], + sampling_params=SamplingParams( + use_beam_search=use_beam_search, + best_of=best_of), + arrival_time=time.time(), + lora_request=lora_request, + encoder_seq=encoder_prompt) + + return decoder_prompt, encoder_prompt, seq_group + + def create_seq_group( - seq_prompt_len=1024, - seq_output_lens=(128, ), - request_id='0', - seq_id_start=0, -) -> SequenceGroup: + seq_prompt_len: int = 1024, + seq_output_lens: Iterable[int] = (128, ), + request_id: str = '0', + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: assert len(seq_output_lens) > 0 + if sampling_params is None: + sampling_params = SamplingParams() + prompt_token_ids = [0] * seq_prompt_len seqs = [] for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={"prompt_token_ids": prompt_token_ids}, block_size=16, ) @@ -60,12 +116,63 @@ def create_seq_group( seq_group = SequenceGroup( request_id=request_id, seqs=seqs, - sampling_params=SamplingParams(), + sampling_params=sampling_params, arrival_time=time.time(), ) return seq_group +def create_seq_group_encoder_decoder( + seq_prompt_len: int = 1024, + seq_output_lens: Iterable[int] = (128, ), + request_id: str = '0', + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: + + assert len(seq_output_lens) > 0 + + if sampling_params is None: + sampling_params = SamplingParams() + + prompt_token_ids = [0] * seq_prompt_len + + seqs = [] + for seq_id_offset, output_len in enumerate(seq_output_lens): + seq = Sequence( + seq_id=seq_id_start + seq_id_offset, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=16, + ) + + for i in range(output_len): + seq.append_token_id( + token_id=i, + logprobs={i: Logprob(0.0)}, + ) + seqs.append(seq) + + # Encoder sequence + encoder_seq = Sequence( + seq_id=seq_id_start + len(seq_output_lens), + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=16, + ) + + return SequenceGroup(request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + encoder_seq=encoder_seq) + + def round_up_to_next_block(seq_len: int, block_size: int) -> int: - return (seq_len + block_size - 1) // block_size + return (seq_len + block_size - 1) // block_size \ No newline at end of file diff --git a/tests/distributed/__init__.py b/tests/distributed/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 1eba14d7a6422..3ba5cea389c2f 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -4,10 +4,12 @@ variables. Run: ```sh +cd $VLLM_PATH/tests + TEST_DIST_MODEL=facebook/opt-125m pytest \ - test_basic_distributed_correctness.py + distributed/test_basic_distributed_correctness.py TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ - test_basic_distributed_correctness.py + distributed/test_basic_distributed_correctness.py ``` """ import os @@ -18,6 +20,8 @@ MODELS = [ os.environ["TEST_DIST_MODEL"], ] +DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -33,11 +37,21 @@ def test_models( dtype: str, max_tokens: int, ) -> None: + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + enforce_eager = backend_by_env_var == "FLASHINFER" + hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2) + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=enforce_eager, + distributed_executor_backend=distributed_executor_backend) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py new file mode 100644 index 0000000000000..db938cc613c6b --- /dev/null +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -0,0 +1,70 @@ +"""Compare the outputs of HF and distributed vLLM when using greedy sampling. +vLLM will allocate all the available memory, so we need to run the tests one +by one. The solution is to pass arguments (model name) by environment +variables. + +Run: +```sh +TEST_DIST_MODEL=facebook/opt-125m pytest \ + test_chunked_prefill_distributed.py +TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ + test_chunked_prefill_distributed.py +``` +""" +import os + +import pytest +import torch + +MODELS = [ + os.environ["TEST_DIST_MODEL"], +] +DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + + # Add a chunked prefill config. + max_num_seqs = min(chunked_prefill_token_size, 256) + assert chunked_prefill_token_size != -1 + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index d1811cb694db6..53654dc40d10d 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -8,15 +8,16 @@ import ray import torch -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) -from vllm.test_utils import (init_test_distributed_environment, - multi_process_tensor_parallel) +from vllm.distributed import (broadcast_tensor_dict, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) + +from ..utils import (init_test_distributed_environment, + multi_process_tensor_parallel) @ray.remote(num_gpus=1, max_calls=1) -def all_reduce_test_worker(tensor_parallel_size: int, rank: int, +def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs @@ -24,12 +25,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_elements = 8 all_tensors = [ torch.arange(num_elements, dtype=torch.float32, device="cuda") * - (r + 1) for r in range(tensor_parallel_size) + (r + 1) for r in range(tp_size) ] expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) t = all_tensors[rank] @@ -38,7 +39,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) -def all_gather_test_worker(tensor_parallel_size: int, rank: int, +def all_gather_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs @@ -46,7 +47,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_dimensions = 3 tensor_size = list(range(2, num_dimensions + 2)) @@ -57,7 +58,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, all_tensors = [ torch.arange(total_size, dtype=torch.float32, device="cuda").reshape(tensor_size) * (r + 1) - for r in range(tensor_parallel_size) + for r in range(tp_size) ] expected = torch.cat(all_tensors, dim=all_gather_dimension) t = all_tensors[rank] @@ -66,7 +67,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) -def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, +def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs @@ -74,17 +75,21 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) test_dict = { + # device tensor "a": torch.arange(8, dtype=torch.float32, device="cuda"), - "b": torch.arange(16, dtype=torch.int8, device="cuda"), + # CPU tensor + "b": torch.arange(16, dtype=torch.int8, device="cpu"), "c": "test", "d": [1, 2, 3], "e": { "a": 1, "b": 2 }, + # empty tensor + "f": torch.tensor([], dtype=torch.float32, device="cuda"), } if rank == 0: @@ -97,14 +102,15 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, assert recv_dict["c"] == test_dict["c"] assert recv_dict["d"] == test_dict["d"] assert recv_dict["e"] == test_dict["e"] + assert torch.allclose(recv_dict["f"], test_dict["f"]) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("tensor_parallel_size", [2]) +@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("test_target", [ all_reduce_test_worker, all_gather_test_worker, broadcast_tensor_dict_test_worker ]) -def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - multi_process_tensor_parallel(tensor_parallel_size, test_target) +def test_multi_process_tensor_parallel(tp_size, test_target): + multi_process_tensor_parallel(tp_size, 1, test_target) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 1e6e7f89a528c..186f9faa6bfb6 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -6,11 +6,13 @@ import torch import torch.distributed as dist -from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.test_utils import (init_test_distributed_environment, - multi_process_tensor_parallel) +from vllm.distributed.communication_op import ( # noqa + graph_capture, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_ca_communicator) + +from ..utils import (init_test_distributed_environment, + multi_process_tensor_parallel) random.seed(42) test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] @@ -19,17 +21,36 @@ @ray.remote(num_gpus=1, max_calls=1) -def graph_allreduce(world_size, rank, distributed_init_port): +def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, world_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) - custom_ar.init_custom_ar() + group = get_tensor_model_parallel_group() + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. + data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 + for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with custom_ar.capture(): + with graph_capture() as graph_capture_context: # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), @@ -41,45 +62,54 @@ def graph_allreduce(world_size, rank, distributed_init_port): device=torch.cuda.current_device()) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - out1 = tensor_model_parallel_all_reduce(inp1) - # the input buffer is immediately modified to test - # synchronization - dist.all_reduce(inp1) - out2 = tensor_model_parallel_all_reduce(inp2) - dist.all_reduce(inp2) + with torch.cuda.graph(graph, + stream=graph_capture_context.stream): + for i in range(num_communication): + out1 = tensor_model_parallel_all_reduce(inp1) + # the input buffer is immediately modified to test + # synchronization + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) graph.replay() assert torch.allclose(out1, inp1) assert torch.allclose(out2, inp2) @ray.remote(num_gpus=1, max_calls=1) -def eager_allreduce(world_size, rank, distributed_init_port): +def eager_allreduce(tp_size, pp_size, rank, distributed_init_port): del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, world_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 sz = 1024 - custom_ar.init_custom_ar() - fa = custom_ar.get_handle() + fa = get_tp_ca_communicator() inp = torch.ones(sz, dtype=torch.float32, device=device) - out = fa.all_reduce_unreg(inp) - assert torch.allclose(out, inp * world_size) + out = inp + for _ in range(num_communication): + out = fa.all_reduce_unreg(out) + assert torch.allclose(out, inp * (tp_size**num_communication)) inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) - out = fa.all_reduce_unreg(inp) - assert torch.allclose(out, inp * world_size) + out = inp + for _ in range(num_communication): + out = fa.all_reduce_unreg(out) + assert torch.allclose(out, inp * (tp_size**num_communication)) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("tensor_parallel_size", [2]) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) @pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce]) -def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - multi_process_tensor_parallel(tensor_parallel_size, test_target) - - -if __name__ == "__main__": - multi_process_tensor_parallel(2, graph_allreduce) +def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 29782045130a6..0218295a3e3f9 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -3,16 +3,22 @@ import pytest import torch +import torch.distributed -from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, - ncclGetUniqueId) +from vllm.distributed.communication_op import ( # noqa + graph_capture, tensor_model_parallel_all_reduce) +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.utils import update_environment_variables def distributed_run(fn, world_size): number_of_processes = world_size processes = [] for i in range(number_of_processes): - env = os.environ.copy() + env = {} env['RANK'] = str(i) env['LOCAL_RANK'] = str(i) env['WORLD_SIZE'] = str(number_of_processes) @@ -26,26 +32,34 @@ def distributed_run(fn, world_size): for p in processes: p.join() + for p in processes: + assert p.exitcode == 0 + -def update_env(fn): +def worker_fn_wrapper(fn): # `multiprocessing.Process` cannot accept environment variables directly # so we need to pass the environment variables as arguments # and update the environment variables in the function - def wrapper(env): - import os - os.environ.update(env) + def wrapped_fn(env): + update_environment_variables(env) + local_rank = os.environ['LOCAL_RANK'] + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + init_distributed_environment() fn() - return wrapper + return wrapped_fn -@update_env +@worker_fn_wrapper def worker_fn(): - comm = NCCLCommunicator() - tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) - comm.all_reduce(tensor) + pynccl_comm = PyNcclCommunicator() + tensor = torch.ones(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + with pynccl_comm.change_state(enable=True): + pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() - assert result == comm.world_size + assert result == pynccl_comm.world_size @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -54,23 +68,82 @@ def test_pynccl(): distributed_run(worker_fn, 2) -@update_env +@worker_fn_wrapper +def multiple_allreduce_worker_fn(): + device = torch.device(f"cuda:{torch.distributed.get_rank()}") + groups = [ + torch.distributed.new_group(ranks=[0, 1], backend="gloo"), + torch.distributed.new_group(ranks=[2, 3], backend="gloo") + ] + group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] + pynccl_comm = PyNcclCommunicator(group=group, device=device) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) + with pynccl_comm.change_state(enable=True): + # two groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + pynccl_comm.all_reduce(tensor) + pynccl_comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + pynccl_comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_multiple_allreduce(): + # this tests pynccl for multiple tp groups, in a standalone way + # i.e. call `pynccl_comm.all_reduce` directly + distributed_run(multiple_allreduce_worker_fn, 4) + + +@worker_fn_wrapper +def multiple_allreduce_with_vllm_worker_fn(): + device = torch.device(f"cuda:{torch.distributed.get_rank()}") + ensure_model_parallel_initialized(2, 2) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) + with graph_capture(): + # two tp groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + tensor = tensor_model_parallel_all_reduce(tensor) + tensor = tensor_model_parallel_all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + tensor = tensor_model_parallel_all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_multiple_allreduce_with_vllm(): + # this tests pynccl for multiple tp groups, together with vllm + # i.e. call `tensor_model_parallel_all_reduce` + distributed_run(multiple_allreduce_with_vllm_worker_fn, 4) + + +@worker_fn_wrapper def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() - comm = NCCLCommunicator() + pynccl_comm = PyNcclCommunicator() # run something in the default stream to initialize torch engine - a = torch.ones((4, 4), device=f'cuda:{comm.rank}') + a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') torch.cuda.synchronize() - with torch.cuda.graph(graph, stream=comm.stream): + with torch.cuda.graph( + graph, stream=pynccl_comm.stream), pynccl_comm.change_state( + enable=True): # operation during the graph capture is recorded but not executed # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa - comm.all_reduce(a) - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**0 + pynccl_comm.all_reduce(a) + pynccl_comm.stream.synchronize() + assert a.mean().cpu().item() == pynccl_comm.world_size**0 graph.replay() - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**1 + pynccl_comm.stream.synchronize() + assert a.mean().cpu().item() == pynccl_comm.world_size**1 @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -79,8 +152,71 @@ def test_pynccl_with_cudagraph(): distributed_run(worker_fn_with_cudagraph, 2) +@worker_fn_wrapper +def send_recv_worker_fn(): + pynccl_comm = PyNcclCommunicator() + if pynccl_comm.rank == 0: + tensor = torch.ones(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + else: + tensor = torch.empty(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + with pynccl_comm.change_state(enable=True): + if pynccl_comm.rank == 0: + pynccl_comm.send(tensor) + else: + pynccl_comm.recv(tensor) + result = tensor.mean().cpu().item() + assert result == 1 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_send_recv(): + distributed_run(send_recv_worker_fn, 2) + + +@worker_fn_wrapper +def multiple_send_recv_worker_fn(): + device = torch.device(f"cuda:{torch.distributed.get_rank()}") + groups = [ + torch.distributed.new_group(ranks=[0, 2], backend="gloo"), + torch.distributed.new_group(ranks=[1, 3], backend="gloo") + ] + group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] + pynccl_comm = PyNcclCommunicator(group=group, device=device) + if torch.distributed.get_rank() == 0: + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) + elif torch.distributed.get_rank() == 1: + tensor = 2 * torch.ones( + 16, 1024, 1024, dtype=torch.float32, device=device) + else: + tensor = torch.empty(16, + 1024, + 1024, + dtype=torch.float32, + device=device) + with pynccl_comm.change_state(enable=True): + if torch.distributed.get_rank() in [0, 1]: + pynccl_comm.send(tensor) + else: + pynccl_comm.recv(tensor) + result = tensor.mean().cpu().item() + if torch.distributed.get_rank() in [0, 2]: + assert result == 1 + else: + assert result == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_multiple_send_recv(): + distributed_run(multiple_send_recv_worker_fn, 4) + + def test_ncclGetUniqueId(): - unique_id = ncclGetUniqueId() + lib = NCCLLibrary() + unique_id = lib.ncclGetUniqueId() # `list(unique_id.internal)` is something like this: # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/tests/engine/__init__.py b/tests/engine/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/engine/output_processor/__init__.py b/tests/engine/output_processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py new file mode 100644 index 0000000000000..4f32a622546f0 --- /dev/null +++ b/tests/engine/output_processor/test_multi_step.py @@ -0,0 +1,271 @@ +import random +from unittest.mock import MagicMock + +import pytest +from transformers import PreTrainedTokenizer + +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sampling_params import SamplingParams +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput, SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + +from ...core.utils import create_seq_group + + +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [1, 12]) +@pytest.mark.skip_global_cleanup +def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): + """Verify multi-step decoding appends token ids correctly. + + We append token ids and verify all the token ids were appended correctly. + Note that ignore_eos=True. + """ + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=1024, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams(max_tokens=seq_output_len + + num_new_tokens, + ignore_eos=True), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + + outputs = [ + CompletionSequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids + output_processor.process_outputs(seq_group, outputs) + assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids + + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8]) +@pytest.mark.parametrize("max_tokens", [128 + 3]) +@pytest.mark.skip_global_cleanup +def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, max_tokens: int): + """Verify tokens after max_tokens are dropped and not appended to the + sequence. + """ + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams(max_tokens=max_tokens, ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + + outputs = [ + CompletionSequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to not go over max tokens in len. + assert seq.get_len() == seq_prompt_len + max_tokens + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [12]) +@pytest.mark.parametrize("seed", list(range(6))) +@pytest.mark.skip_global_cleanup +def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, seed: int): + """Verify the eos token id is included in the sequence, but subsequent + tokens are dropped (not appended to sequence). + """ + random.seed(seed) + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + eos_token_id = 100 + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams( + # Ensure enough space. + max_tokens=seq_output_len + num_new_tokens, ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + assert eos_token_id not in new_token_ids + eos_index = random.randint(0, len(new_token_ids) - 1) + new_token_ids[eos_index] = eos_token_id + + outputs = [ + CompletionSequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to not go beyond provided eos. + assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1) + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:eos_index + 1] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [12]) +@pytest.mark.parametrize("seed", list(range(6))) +@pytest.mark.skip_global_cleanup +def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, seed: int): + """When sampling parameters dictate that we should ignore the eos token id, + ensure all token ids are appended even if the eos token id is emitted. + """ + random.seed(seed) + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + eos_token_id = 100 + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams( + # Ensure enough space. + max_tokens=seq_output_len + num_new_tokens, + ignore_eos=True, + ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + assert eos_token_id not in new_token_ids + eos_index = random.randint(0, len(new_token_ids) - 1) + new_token_ids[eos_index] = eos_token_id + + outputs = [ + CompletionSequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to go beyond eos. + assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens - + seq_output_len] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + + +def mock_tokenizer(eos_token_id=1000): + tokenizer = MagicMock(spec=PreTrainedTokenizer) + tokenizer.eos_token_id = eos_token_id + return tokenizer diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py new file mode 100644 index 0000000000000..f795403e3d8ad --- /dev/null +++ b/tests/engine/output_processor/test_stop_checker.py @@ -0,0 +1,85 @@ +from unittest.mock import MagicMock + +import pytest +from transformers import PreTrainedTokenizer + +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sampling_params import SamplingParams +from vllm.sequence import Logprob, Sequence, SequenceStatus + + +def sequence_with_eos(text: str, eos_token: str, + eos_token_id: int) -> Sequence: + """ + Create a Sequence that ends with an EOS token. + """ + seq = Sequence( + seq_id=0, + inputs={"prompt_token_ids": []}, + block_size=16, + eos_token_id=eos_token_id, + ) + seq.output_text = text + eos_token + + offset = eos_token_id + 1 + for i in range(offset, len(text) + offset): + seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)}) + seq.append_token_id(token_id=eos_token_id, + logprobs={eos_token_id: Logprob(0.0)}) + + seq.status = SequenceStatus.RUNNING + + return seq + + +@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [ + ("This text ends with EOS token", "", 2), +]) +@pytest.mark.parametrize("ignore_eos", [True, False, None]) +@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None]) +@pytest.mark.skip_global_cleanup +def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, + ignore_eos: bool, include_stop_str_in_output: bool): + """ + Test the behavior of the StopChecker's maybe_stop_sequence method + when an EOS token is encountered. + + This test covers: + - When the EOS token should stop the sequence and be removed from the output + - When the EOS token should stop the sequence and be included in the output + - When the EOS token should be ignored, and the sequence continues + """ + + tokenizer = MagicMock(spec=PreTrainedTokenizer) + get_tokenizer_for_seq = MagicMock(return_value=tokenizer) + stop_checker = StopChecker(max_model_len=1024, + get_tokenizer_for_seq=get_tokenizer_for_seq) + + seq = sequence_with_eos( + text=text_wo_eos, + eos_token=eos_token, + eos_token_id=eos_token_id, + ) + new_char_count = len(eos_token) + + # Note that `stop` and `stop_token_ids` are not specified + sampling_params = SamplingParams( + min_tokens=1, + ignore_eos=ignore_eos, + include_stop_str_in_output=include_stop_str_in_output) + + stop_checker.maybe_stop_sequence( + seq=seq, + new_char_count=new_char_count, + sampling_params=sampling_params, + ) + + if ignore_eos: + assert seq.status == SequenceStatus.RUNNING + assert seq.output_text == text_wo_eos + eos_token + elif include_stop_str_in_output: + assert seq.status == SequenceStatus.FINISHED_STOPPED + assert seq.output_text == text_wo_eos + eos_token + else: + assert seq.status == SequenceStatus.FINISHED_STOPPED + assert seq.output_text == text_wo_eos diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py new file mode 100644 index 0000000000000..610ad9732fb91 --- /dev/null +++ b/tests/engine/test_multiproc_workers.py @@ -0,0 +1,176 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from time import sleep +from typing import Any, List, Tuple + +import pytest + +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) + + +class DummyWorker: + """Dummy version of vllm.worker.worker.Worker""" + + def __init__(self, rank: int): + self.rank = rank + + def worker_method(self, worker_input: Any) -> Tuple[int, Any]: + sleep(0.05) + + if isinstance(worker_input, Exception): + # simulate error case + raise worker_input + + return self.rank, input + + +def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]: + result_handler = ResultHandler() + workers = [ + ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank)) + for rank in range(8) + ] + + worker_monitor = WorkerMonitor(workers, result_handler) + assert not worker_monitor.is_alive() + + result_handler.start() + worker_monitor.start() + assert worker_monitor.is_alive() + + return workers, worker_monitor + + +def test_local_workers() -> None: + """Test workers with sync task submission""" + + workers, worker_monitor = _start_workers() + + def execute_workers(worker_input: str) -> None: + worker_outputs = [ + worker.execute_method("worker_method", worker_input) + for worker in workers + ] + + for rank, output in enumerate(worker_outputs): + assert output.get() == (rank, input) + + executor = ThreadPoolExecutor(max_workers=4) + + # Test concurrent submission from different threads + futures = [ + executor.submit(partial(execute_workers, f"thread {thread_num}")) + for thread_num in range(4) + ] + + for future in futures: + future.result() + + # Test error case + exception = ValueError("fake error") + result = workers[0].execute_method("worker_method", exception) + try: + result.get() + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].process.kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +def test_local_workers_clean_shutdown() -> None: + """Test clean shutdown""" + + workers, worker_monitor = _start_workers() + + assert worker_monitor.is_alive() + assert all(worker.process.is_alive() for worker in workers) + + # Clean shutdown + worker_monitor.close() + + worker_monitor.join(5) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +@pytest.mark.asyncio +async def test_local_workers_async() -> None: + """Test local workers with async task submission""" + + workers, worker_monitor = _start_workers() + + async def execute_workers(worker_input: str) -> None: + worker_coros = [ + worker.execute_method_async("worker_method", worker_input) + for worker in workers + ] + + results = await asyncio.gather(*worker_coros) + for rank, result in enumerate(results): + assert result == (rank, input) + + tasks = [ + asyncio.create_task(execute_workers(f"task {task_num}")) + for task_num in range(4) + ] + + for task in tasks: + await task + + # Test error case + exception = ValueError("fake error") + try: + _result = await workers[0].execute_method_async( + "worker_method", exception) + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].process.kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = await workers[0].execute_method_async( + "worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py new file mode 100644 index 0000000000000..338b208723ba9 --- /dev/null +++ b/tests/engine/test_skip_tokenizer_init.py @@ -0,0 +1,23 @@ +import pytest + +from vllm.entrypoints.llm import LLM +from vllm.sampling_params import SamplingParams + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +def test_skip_tokenizer_initialization(model: str): + # This test checks if the flag skip_tokenizer_init skips the initialization + # of tokenizer and detokenizer. The generated output is expected to contain + # token ids. + llm = LLM(model=model, skip_tokenizer_init=True) + sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) + with pytest.raises(ValueError) as err: + llm.generate("abc", sampling_params) + assert "prompts must be None if" in str(err.value) + outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, + sampling_params=sampling_params) + assert len(outputs) > 0 + completions = outputs[0].outputs + assert len(completions) > 0 + assert completions[0].text == "" + assert completions[0].token_ids diff --git a/tests/samplers/test_stop_reason.py b/tests/engine/test_stop_reason.py similarity index 86% rename from tests/samplers/test_stop_reason.py rename to tests/engine/test_stop_reason.py index b242c405a4fb6..7b886507c04f2 100644 --- a/tests/samplers/test_stop_reason.py +++ b/tests/engine/test_stop_reason.py @@ -3,7 +3,7 @@ 2. One of the provided stop tokens 3. The EOS token -Run `pytest tests/samplers/test_stop_reason.py`. +Run `pytest tests/engine/test_stop_reason.py`. """ import pytest @@ -32,6 +32,7 @@ def test_stop_reason(vllm_model, example_prompts): # test stop token outputs = llm.generate(example_prompts, sampling_params=SamplingParams( + ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop_token_ids=[stop_token_id])) @@ -43,7 +44,10 @@ def test_stop_reason(vllm_model, example_prompts): # test stop string outputs = llm.generate(example_prompts, sampling_params=SamplingParams( - seed=SEED, max_tokens=MAX_TOKENS, stop=".")) + ignore_eos=True, + seed=SEED, + max_tokens=MAX_TOKENS, + stop=".")) for output in outputs: output = output.outputs[0] assert output.finish_reason == "stop" diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py new file mode 100644 index 0000000000000..6b747beb4b543 --- /dev/null +++ b/tests/engine/test_stop_strings.py @@ -0,0 +1,111 @@ +from typing import Any, List, Optional + +import pytest + +from vllm import CompletionOutput, LLMEngine, SamplingParams + +MODEL = "meta-llama/llama-2-7b-hf" +MAX_TOKENS = 200 + + +@pytest.fixture(scope="session") +def vllm_model(vllm_runner): + return vllm_runner(MODEL) + + +@pytest.mark.skip_global_cleanup +def test_stop_basic(vllm_model): + _test_stopping(vllm_model.model.llm_engine, + stop=["."], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=".") + + _test_stopping(vllm_model.model.llm_engine, + stop=["."], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization.", + expected_reason=".") + + +@pytest.mark.skip_global_cleanup +def test_stop_multi_tokens(vllm_model): + _test_stopping( + vllm_model.model.llm_engine, + stop=["group of peo", "short"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization. We are a ", + expected_reason="group of peo") + + _test_stopping( + vllm_model.model.llm_engine, + stop=["group of peo", "short"], + include_in_output=True, + expected_output= + "VLLM is a 100% volunteer organization. We are a group of peo", + expected_reason="group of peo") + + +@pytest.mark.skip_global_cleanup +def test_stop_partial_token(vllm_model): + _test_stopping(vllm_model.model.llm_engine, + stop=["gani"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer or", + expected_reason="gani") + + _test_stopping(vllm_model.model.llm_engine, + stop=["gani"], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organi", + expected_reason="gani") + + +@pytest.mark.skip_global_cleanup +def test_stop_token_id(vllm_model): + # token id 13013 => " organization" + + _test_stopping(vllm_model.model.llm_engine, + stop_token_ids=[13013], + include_in_output=False, + expected_output="VLLM is a 100% volunteer", + expected_reason=13013) + + _test_stopping(vllm_model.model.llm_engine, + stop_token_ids=[13013], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=13013) + + +def _test_stopping(llm_engine: LLMEngine, + expected_output: str, + expected_reason: Any, + stop: Optional[List[str]] = None, + stop_token_ids: Optional[List[int]] = None, + include_in_output: bool = False) -> None: + llm_engine.add_request( + "id", "A story about vLLM:\n", + SamplingParams( + temperature=0.0, + max_tokens=MAX_TOKENS, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_in_output, + ), None) + + output: Optional[CompletionOutput] = None + output_text = "" + stop_reason = None + while llm_engine.has_unfinished_requests(): + (request_output, ) = llm_engine.step() + (output, ) = request_output.outputs + + # Ensure we don't backtrack + assert output.text.startswith(output_text) + output_text = output.text + stop_reason = output.stop_reason + + assert output is not None + assert output_text == expected_output + assert stop_reason == expected_reason diff --git a/tests/entrypoints/__init__.py b/tests/entrypoints/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py new file mode 100644 index 0000000000000..c45f02fe564a3 --- /dev/null +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -0,0 +1,46 @@ +import asyncio +from dataclasses import dataclass + +import pytest + +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + +MODEL_NAME = "openai-community/gpt2" +CHAT_TEMPLATE = "Dummy chat template for testing {}" + +pytestmark = pytest.mark.openai + + +@dataclass +class MockModelConfig: + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + embedding_mode = False + + +@dataclass +class MockEngine: + + async def get_model_config(self): + return MockModelConfig() + + +async def _async_serving_chat_init(): + engine = MockEngine() + model_config = await engine.get_model_config() + + serving_completion = OpenAIServingChat(engine, + model_config, + served_model_names=[MODEL_NAME], + response_role="assistant", + chat_template=CHAT_TEMPLATE) + return serving_completion + + +def test_async_serving_chat_init(): + serving_completion = asyncio.run(_async_serving_chat_init()) + assert serving_completion.tokenizer is not None + assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 5622744566bcc..5d4163e96fd87 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -1,11 +1,14 @@ # This unit test should be moved to a new # tests/test_guided_decoding directory. - +import pytest import torch from transformers import AutoTokenizer -from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor, - RegexLogitsProcessor) +from vllm.entrypoints.openai.protocol import CompletionRequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + JSONLogitsProcessor, RegexLogitsProcessor) TEST_SCHEMA = { "type": "object", @@ -49,12 +52,16 @@ TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") +pytestmark = pytest.mark.openai + def test_guided_logits_processors(): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer) - json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer) + json_LP = JSONLogitsProcessor(TEST_SCHEMA, + tokenizer, + whitespace_pattern=None) regex_LP.init_state() token_ids = tokenizer.encode( @@ -73,3 +80,36 @@ def test_guided_logits_processors(): json_LP(token_ids, tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"]) +async def test_guided_logits_processor_black_box(backend: str): + tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') + token_ids = tokenizer.encode( + f"Give an example IPv4 address with this regex: {TEST_REGEX}") + regex_request = CompletionRequest(model='test', + prompt=token_ids, + guided_regex=TEST_REGEX) + regex_lp = await get_guided_decoding_logits_processor( + backend, regex_request, tokenizer) + assert regex_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = regex_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) + + token_ids = tokenizer.encode( + f"Give an employee profile that fits this schema: {TEST_SCHEMA}") + json_request = CompletionRequest(model='test', + prompt=token_ids, + guided_json=TEST_SCHEMA) + json_lp = await get_guided_decoding_logits_processor( + backend, json_request, tokenizer) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py new file mode 100644 index 0000000000000..7c3fbe43a8384 --- /dev/null +++ b/tests/entrypoints/test_llm_encode.py @@ -0,0 +1,144 @@ +import weakref +from typing import List + +import pytest + +from vllm import LLM, EmbeddingRequestOutput, PoolingParams + +from ..conftest import cleanup + +MODEL_NAME = "intfloat/e5-mistral-7b-instruct" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +TOKEN_IDS = [ + # Using ID={0, 1, 2, 3} results in NaN values, + # so we add this offset of 1000 + [1000], + [1000, 1001], + [1000, 1002, 1001], + [1000, 1003, 1001, 1002], +] + +pytestmark = pytest.mark.llm + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup() + + +def assert_outputs_equal(o1: List[EmbeddingRequestOutput], + o2: List[EmbeddingRequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) + + v2_output = llm.encode(prompt, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=prompt_token_ids, + pooling_params=pooling_params) + + v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, + pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) + + v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode( + [{ + "prompt": p + } for p in PROMPTS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, + pooling_params=pooling_params) + + v2_output = llm.encode( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_multiple_pooling_params(llm: LLM): + pooling_params = [ + PoolingParams(), + PoolingParams(), + PoolingParams(), + PoolingParams(), + ] + + # Multiple PoolingParams should be matched with each prompt + outputs = llm.encode(PROMPTS, pooling_params=pooling_params) + assert len(PROMPTS) == len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3]) + + # Single PoolingParams should be applied to every prompt + single_pooling_params = PoolingParams() + outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params) + assert len(PROMPTS) == len(outputs) + + # pooling_params is None, default params should be applied + outputs = llm.encode(PROMPTS, pooling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py new file mode 100644 index 0000000000000..a00fff91a310e --- /dev/null +++ b/tests/entrypoints/test_llm_generate.py @@ -0,0 +1,144 @@ +import weakref +from typing import List + +import pytest + +from vllm import LLM, RequestOutput, SamplingParams + +from ..conftest import cleanup + +MODEL_NAME = "facebook/opt-125m" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +TOKEN_IDS = [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], +] + +pytestmark = pytest.mark.llm + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup() + + +def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=prompt, + sampling_params=sampling_params) + + v2_output = llm.generate(prompt, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate({"prompt": prompt}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params) + + v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=PROMPTS, + sampling_params=sampling_params) + + v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate( + [{ + "prompt": p + } for p in PROMPTS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, + sampling_params=sampling_params) + + v2_output = llm.generate( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_multiple_sampling_params(llm: LLM): + sampling_params = [ + SamplingParams(temperature=0.01, top_p=0.95), + SamplingParams(temperature=0.3, top_p=0.95), + SamplingParams(temperature=0.7, top_p=0.95), + SamplingParams(temperature=0.99, top_p=0.95), + ] + + # Multiple SamplingParams should be matched with each prompt + outputs = llm.generate(PROMPTS, sampling_params=sampling_params) + assert len(PROMPTS) == len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3]) + + # Single SamplingParams should be applied to every prompt + single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95) + outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params) + assert len(PROMPTS) == len(outputs) + + # sampling_params is None, default params should be applied + outputs = llm.generate(PROMPTS, sampling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/test_openai_run_batch.py b/tests/entrypoints/test_openai_run_batch.py new file mode 100644 index 0000000000000..5de28513ca391 --- /dev/null +++ b/tests/entrypoints/test_openai_run_batch.py @@ -0,0 +1,53 @@ +import subprocess +import sys +import tempfile + +from vllm.entrypoints.openai.protocol import BatchRequestOutput + +# ruff: noqa: E501 +INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" + +INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" + + +def test_e2e(): + with tempfile.NamedTemporaryFile( + "w") as input_file, tempfile.NamedTemporaryFile( + "r") as output_file: + input_file.write(INPUT_BATCH) + input_file.flush() + proc = subprocess.Popen([ + sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i", + input_file.name, "-o", output_file.name, "--model", + "NousResearch/Meta-Llama-3-8B-Instruct" + ], ) + proc.communicate() + proc.wait() + assert proc.returncode == 0, f"{proc=}" + + contents = output_file.read() + for line in contents.strip().split("\n"): + # Ensure that the output format conforms to the openai api. + # Validation should throw if the schema is wrong. + BatchRequestOutput.model_validate_json(line) + + +def test_e2e_invalid_input(): + """ + Ensure that we fail when the input doesn't conform to the openai api. + """ + with tempfile.NamedTemporaryFile( + "w") as input_file, tempfile.NamedTemporaryFile( + "r") as output_file: + input_file.write(INVALID_INPUT_BATCH) + input_file.flush() + proc = subprocess.Popen([ + sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i", + input_file.name, "-o", output_file.name, "--model", + "NousResearch/Meta-Llama-3-8B-Instruct" + ], ) + proc.communicate() + proc.wait() + assert proc.returncode != 0, f"{proc=}" diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 442f8bdf3b4ba..972137030f46f 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -1,10 +1,6 @@ # imports for guided decoding tests import json -import os import re -import subprocess -import sys -import time import jsonschema import openai # use the official client for correctness check @@ -12,15 +8,18 @@ # using Ray for overall ease of process management, parallel requests, # and debugging. import ray -import requests +import torch # downloading lora to test lora requests from huggingface_hub import snapshot_download +from openai import BadRequestError from vllm.transformers_utils.tokenizer import get_tokenizer -MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds +from ..utils import ServerRunner + # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" # technically this needs Mistral-7B-v0.1 as base, but we're not testing # generation quality here LORA_NAME = "typeof/zephyr-7b-beta-lora" @@ -72,46 +71,7 @@ "Swift", "Kotlin" ] -pytestmark = pytest.mark.asyncio - - -@ray.remote(num_gpus=1) -class ServerRunner: - - def __init__(self, args): - env = os.environ.copy() - env["PYTHONUNBUFFERED"] = "1" - self.proc = subprocess.Popen( - ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, - env=env, - stdout=sys.stdout, - stderr=sys.stderr, - ) - self._wait_for_server() - - def ready(self): - return True - - def _wait_for_server(self): - # run health check - start = time.time() - while True: - try: - if requests.get( - "http://localhost:8000/health").status_code == 200: - break - except Exception as err: - if self.proc.poll() is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - time.sleep(0.5) - if time.time() - start > MAX_SERVER_START_WAIT_S: - raise RuntimeError( - "Server failed to start in time.") from err - - def __del__(self): - if hasattr(self, "proc"): - self.proc.terminate() +pytestmark = pytest.mark.openai @pytest.fixture(scope="session") @@ -119,7 +79,7 @@ def zephyr_lora_files(): return snapshot_download(repo_id=LORA_NAME) -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(zephyr_lora_files): ray.init() server_runner = ServerRunner.remote([ @@ -131,6 +91,8 @@ def server(zephyr_lora_files): "--max-model-len", "8192", "--enforce-eager", + "--gpu-memory-utilization", + "0.75", # lora config below "--enable-lora", "--lora-modules", @@ -141,14 +103,35 @@ def server(zephyr_lora_files): "--max-cpu-loras", "2", "--max-num-seqs", - "128" + "128", ]) ray.get(server_runner.ready.remote()) yield server_runner ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") +def embedding_server(zephyr_lora_files): + ray.shutdown() + ray.init() + server_runner = ServerRunner.remote([ + "--model", + EMBEDDING_MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--enforce-eager", + "--gpu-memory-utilization", + "0.75", + "--max-model-len", + "8192", + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", @@ -157,6 +140,7 @@ def client(): yield client +@pytest.mark.asyncio async def test_check_models(server, client: openai.AsyncOpenAI): models = await client.models.list() models = models.data @@ -168,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI): assert lora_models[1].id == "zephyr-lora2" +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -199,6 +184,27 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, completion.choices[0].text) >= 5 +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_no_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -217,9 +223,75 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, choice = completion.choices[0] assert choice.logprobs is not None assert choice.logprobs.token_logprobs is not None - assert choice.logprobs.top_logprobs is None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) <= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_some_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) <= 6 +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=6, + ) + ... + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=6, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 + + +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -246,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, chat_completion.choices) == 1 assert chat_completion.choices[0].message is not None assert chat_completion.choices[0].logprobs is not None - assert chat_completion.choices[0].logprobs.top_logprobs is not None - assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 + assert chat_completion.choices[0].logprobs.content[ + 0].top_logprobs is not None + assert len( + chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5 message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role == "assistant" @@ -264,9 +338,93 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=False) + + choice = chat_completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # just test 1 lora hereafter + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=True, + top_logprobs=0) + + choice = chat_completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.content is not None + assert len(choice.logprobs.content[0].top_logprobs) <= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=True, + top_logprobs=5) + + choice = chat_completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.content is not None + assert len(choice.logprobs.content[0].top_logprobs) <= 6 + + +@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, - model_name: str): +async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -275,13 +433,13 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, "content": "what is 1+1?" }] - # Default max_logprobs is 5, so this should raise an error + # Default max_logprobs is 20, so this should raise an error with pytest.raises((openai.BadRequestError, openai.APIError)): stream = await client.chat.completions.create(model=model_name, messages=messages, max_tokens=10, logprobs=True, - top_logprobs=10, + top_logprobs=21, stream=True) async for chunk in stream: ... @@ -291,25 +449,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, messages=messages, max_tokens=10, logprobs=True, - top_logprobs=10, + top_logprobs=30, stream=False) - with pytest.raises((openai.BadRequestError, openai.APIError)): - stream = await client.completions.create(model=model_name, - prompt="Test", - max_tokens=10, - logprobs=10, - stream=True) - async for chunk in stream: - ... - - with pytest.raises(openai.BadRequestError): - await client.completions.create(model=model_name, - prompt="Test", - max_tokens=10, - logprobs=10, - stream=False) - # the server should still work afterwards chat_completion = await client.chat.completions.create(model=model_name, messages=messages, @@ -319,6 +461,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -356,6 +499,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == single_output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -406,6 +550,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -459,6 +604,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI, assert texts[0] == texts[1] +@pytest.mark.asyncio async def test_logits_bias(server, client: openai.AsyncOpenAI): prompt = "Hello, my name is" max_tokens = 5 @@ -506,7 +652,11 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): assert first_response != completion.choices[0].text -async def test_guided_json_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example JSON for an employee profile " @@ -514,7 +664,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): n=3, temperature=1.0, max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -524,7 +675,11 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) -async def test_guided_json_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -538,8 +693,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + max_tokens=1000, + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) message = chat_completion.choices[0].message assert message.content is not None json1 = json.loads(message.content) @@ -555,8 +711,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + max_tokens=1000, + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) message = chat_completion.choices[0].message assert message.content is not None json2 = json.loads(message.content) @@ -565,14 +722,19 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): assert json1["age"] != json2["age"] -async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}", n=3, temperature=1.0, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -581,7 +743,11 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None -async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -595,7 +761,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) ip1 = chat_completion.choices[0].message.content assert ip1 is not None assert re.fullmatch(TEST_REGEX, ip1) is not None @@ -606,21 +773,27 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) ip2 = chat_completion.choices[0].message.content assert ip2 is not None assert re.fullmatch(TEST_REGEX, ip2) is not None assert ip1 != ip2 -async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt="The best language for type-safe systems programming is ", n=2, temperature=1.0, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 2 @@ -628,7 +801,11 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): assert completion.choices[i].text in TEST_CHOICE -async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -642,7 +819,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) choice1 = chat_completion.choices[0].message.content assert choice1 in TEST_CHOICE @@ -655,18 +833,24 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) choice2 = chat_completion.choices[0].message.content assert choice2 in TEST_CHOICE assert choice1 != choice2 -async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42)) + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) messages = [{ "role": "system", @@ -692,22 +876,122 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + "The best language for type-safe systems programming is " + }] + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5, + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) + top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs + + # -9999.0 is the minimum logprob returned by OpenAI + assert all( + isinstance(token.logprob, float) and token.logprob >= -9999.0 + for token in top_logprobs) + + +@pytest.mark.asyncio async def test_response_format_json_object(server, client: openai.AsyncOpenAI): + for _ in range(2): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": + "user", + "content": ('what is 1+1? please respond with a JSON object, ' + 'the format is {"result": 2}') + }], + response_format={"type": "json_object"}) + + content = resp.choices[0].message.content + loaded = json.loads(content) + assert loaded == {"result": 2}, loaded + + +@pytest.mark.asyncio +async def test_extra_fields(server, client: openai.AsyncOpenAI): + with pytest.raises(BadRequestError) as exc_info: + await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "system", + "content": "You are a helpful assistant.", + "extra_field": "0", + }], # type: ignore + temperature=0, + seed=0) + + assert "extra_forbidden" in exc_info.value.message + + +@pytest.mark.asyncio +async def test_complex_message_content(server, client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, messages=[{ "role": "user", - "content": ('what is 1+1? please respond with a JSON object, ' - 'the format is {"result": 2}') + "content": [{ + "type": + "text", + "text": + "what is 1+1? please provide the result without any other text." + }] }], - response_format={"type": "json_object"}) - + temperature=0, + seed=0) content = resp.choices[0].message.content - loaded = json.loads(content) - assert loaded == {"result": 2}, loaded + assert content == "2" + +@pytest.mark.asyncio +async def test_custom_role(server, client: openai.AsyncOpenAI): + # Not sure how the model handles custom roles so we just check that + # both string and complex message content are handled in the same way + + resp1 = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "my-custom-role", + "content": "what is 1+1?", + }], # type: ignore + temperature=0, + seed=0) + resp2 = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "my-custom-role", + "content": [{ + "type": "text", + "text": "what is 1+1?" + }] + }], # type: ignore + temperature=0, + seed=0) + + content1 = resp1.choices[0].message.content + content2 = resp2.choices[0].message.content + assert content1 == content2 + + +@pytest.mark.asyncio async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement @@ -742,5 +1026,133 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI): assert content.strip() == ground_truth +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, + model_name: str): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=1) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert (completion.choices[0].text is not None + and re.search(r"^" + prompt_text, completion.choices[0].text)) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + assert len(logprobs.tokens) > 5 + + +@pytest.mark.asyncio +async def test_long_seed(server, client: openai.AsyncOpenAI): + for seed in [ + torch.iinfo(torch.long).min - 1, + torch.iinfo(torch.long).max + 1 + ]: + with pytest.raises(BadRequestError) as exc_info: + await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "system", + "content": "You are a helpful assistant.", + }], + temperature=0, + seed=seed) + + assert ("greater_than_equal" in exc_info.value.message + or "less_than_equal" in exc_info.value.message) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, + model_name: str): + input = [ + "The chef prepared a delicious meal.", + ] + + # test single embedding + embeddings = await client.embeddings.create( + model=model_name, + input=input, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 9 + assert embeddings.usage.total_tokens == 9 + + # test using token IDs + input = [1, 1, 1, 1, 1] + embeddings = await client.embeddings.create( + model=model_name, + input=input, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 5 + assert embeddings.usage.total_tokens == 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI, + model_name: str): + # test List[str] + inputs = [ + "The cat sat on the mat.", "A feline was resting on a rug.", + "Stars twinkle brightly in the night sky." + ] + embeddings = await client.embeddings.create( + model=model_name, + input=inputs, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 3 + assert len(embeddings.data[0].embedding) == 4096 + + # test List[List[int]] + inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], + [25, 32, 64, 77]] + embeddings = await client.embeddings.create( + model=model_name, + input=inputs, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 4 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 17 + assert embeddings.usage.total_tokens == 17 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index 22e65bf7e7da1..3e55d7f4297fb 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -1,7 +1,7 @@ -import multiprocessing import sys import time +import pytest import torch from openai import OpenAI, OpenAIError @@ -10,6 +10,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.utils import get_open_port +pytestmark = pytest.mark.openai + class MyOPTForCausalLM(OPTForCausalLM): @@ -26,15 +28,16 @@ def server_function(port): # register our dummy model ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) sys.argv = ["placeholder.py"] + \ - ("--model facebook/opt-125m --dtype" - f" float32 --api-key token-abc123 --port {port}").split() + ("--model facebook/opt-125m --gpu-memory-utilization 0.10 " + f"--dtype float32 --api-key token-abc123 --port {port}").split() import runpy runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') def test_oot_registration_for_api_server(): port = get_open_port() - server = multiprocessing.Process(target=server_function, args=(port, )) + ctx = torch.multiprocessing.get_context() + server = ctx.Process(target=server_function, args=(port, )) server.start() client = OpenAI( base_url=f"http://localhost:{port}/v1", diff --git a/tests/kernels/__init__.py b/tests/kernels/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index d26da2c7fe4ee..4f2f9cc3dac7d 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -1,8 +1,14 @@ import pytest -from vllm.utils import create_kv_caches_with_random +from vllm.utils import (create_kv_caches_with_random, + create_kv_caches_with_random_flash) @pytest.fixture() def kv_cache_factory(): return create_kv_caches_with_random + + +@pytest.fixture() +def kv_cache_factory_flashinfer(): + return create_kv_caches_with_random_flash diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 86ecc6412c648..a624c4ca9ee62 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -2,11 +2,12 @@ import pytest import torch -from allclose_default import get_default_atol, get_default_rtol from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul, NewGELU, SiluAndMul) +from .allclose_default import get_default_atol, get_default_rtol + DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 03ea72924921e..8bc4766fc93c4 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -3,13 +3,14 @@ import pytest import torch -from allclose_default import get_default_atol, get_default_rtol from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -from vllm._C import cache_ops, ops +from vllm import _custom_ops as ops from vllm.utils import get_max_shared_memory_bytes, is_hip +from .allclose_default import get_default_atol, get_default_rtol + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer @@ -27,7 +28,7 @@ # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64, 80, 96, 112, 128, 256 +HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256 ] if not is_hip() else [64, 80, 96, 112, 128] BLOCK_SIZES = [16, 32] @@ -61,7 +62,7 @@ def ref_single_query_cached_kv_attention( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seq_lens: torch.Tensor, scale: float, alibi_slopes: Optional[torch.Tensor], ) -> None: @@ -72,15 +73,15 @@ def ref_single_query_cached_kv_attention( num_seqs = query.shape[0] block_tables = block_tables.cpu().tolist() - context_lens = context_lens.cpu().tolist() + seq_lens = seq_lens.cpu().tolist() for i in range(num_seqs): q = query[i].unsqueeze(0) block_table = block_tables[i] - context_len = int(context_lens[i]) + seq_len = int(seq_lens[i]) keys = [] values = [] - for j in range(context_len): + for j in range(seq_len): block_number = int(block_table[j // block_size]) block_offset = j % block_size @@ -100,8 +101,8 @@ def ref_single_query_cached_kv_attention( alibi_bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. - position_ids = torch.arange(context_len).int() - alibi_bias = (position_ids - context_len + 1).float() + position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) @@ -149,13 +150,13 @@ def test_paged_attention( if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int) + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -186,16 +187,15 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, ) elif version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( @@ -218,9 +218,9 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -237,14 +237,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(key_cache, dequantized_key_cache) + ops.convert_fp8(dequantized_key_cache, key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(value_cache, dequantized_value_cache) + ops.convert_fp8(dequantized_value_cache, value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) @@ -255,7 +255,7 @@ def test_paged_attention( key_cache, value_cache, block_tables, - context_lens, + seq_lens, scale, alibi_slopes, ) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py new file mode 100644 index 0000000000000..f439afa9b7d2b --- /dev/null +++ b/tests/kernels/test_attention_selector.py @@ -0,0 +1,84 @@ +import os +from unittest.mock import patch + +import pytest +import torch + +from vllm.attention.selector import which_attn_to_use + + +@pytest.mark.parametrize( + "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) +@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) +def test_env(name: str, device: str): + """Test that the attention selector can be set via environment variable. + Note that we do not test FlashAttn because it is the default backend. + """ + name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) + os.environ["VLLM_ATTENTION_BACKEND"] = name + + if device == "cpu": + with patch("vllm.attention.selector.is_cpu", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "TORCH_SDPA" + elif device == "hip": + with patch("vllm.attention.selector.is_hip", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "ROCM_FLASH" + else: + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == name + + if name_backup is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + + +def test_flash_attn(): + """Test FlashAttn validation.""" + name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) + os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" + + # Unsupported CUDA arch + with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported data type + backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported kv cache data type + backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported block size + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + assert backend.name != "FLASH_ATTN" + + # Unsupported sliding window + backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + # flash-attn is not installed + with patch.dict('sys.modules', {'vllm_flash_attn': None}): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported head size + backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + if name_backup is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + + +def test_invalid_env(): + """Throw an exception if the backend name is invalid.""" + name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) + os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID" + with pytest.raises(ValueError): + which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py new file mode 100644 index 0000000000000..9da13ca6e2310 --- /dev/null +++ b/tests/kernels/test_blocksparse_attention.py @@ -0,0 +1,442 @@ +import random +from typing import List, Optional, Tuple + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn) +from vllm.utils import get_max_shared_memory_bytes, is_hip + +from .allclose_default import get_default_atol, get_default_rtol + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# MAX_SEQ_LEN = 2771 + +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. +NUM_BLOCKS = 4321 # Arbitrary values for testing +PARTITION_SIZE = 512 +DTYPES = [torch.half, torch.bfloat16] +NUM_GEN_SEQS = [3] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing + +HEAD_SIZES = [64, 112] +BLOCK_SIZES = [16, 32] +USE_ALIBI = [False, True] +KV_CACHE_DTYPE = ["auto", "fp8"] +SEEDS = [0] +CUDA_DEVICES = ['cuda:0'] +BLOCKSPARSE_LOCAL_BLOCKS = [16] +BLOCKSPARSE_VERT_STRIDES = [8] + +BLOCKSPARSE_BLOCK_SIZES = [64] +BLOCKSPARSE_HEADS_SLIDINGS = [0, 2, -1] +BLOCKSPARSE_HOMO_HEADS = [True, False] + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 1, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + seq_lens = seq_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + seq_len = int(seq_lens[i]) + + keys = [] + values = [] + for j in range(seq_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. + position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + if blocksparse_vert_stride >= 1: + bsize = blocksparse_block_size + hsliding = blocksparse_head_sliding_step + vert = blocksparse_vert_stride + locals = blocksparse_local_blocks + qb = (seq_len - 1) // bsize + attn_mask = q.new_zeros( + (num_query_heads, 1, seq_len)).float() - torch.inf + for h in range(num_query_heads): + if hsliding >= 0: # slide with q heads + bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1 + else: # slide with kv heads + bs_offset = (tp_rank * num_kv_heads + + h // num_queries_per_kv) * (-hsliding) + 1 + for kb in range(qb + 1): + kj = kb * bsize + if (qb - kb) < locals or \ + (kb + bs_offset) % vert == 0: + attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0 + if alibi_bias is not None: + attn_mask += alibi_bias + else: + attn_mask = alibi_bias + + out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize("version", ["v1", "v2"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) +@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) +@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) +@pytest.mark.parametrize("blocksparse_head_sliding_step", + BLOCKSPARSE_HEADS_SLIDINGS) +def test_paged_attention( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, + blocksparse_local_blocks: int, + blocksparse_vert_stride: int, + blocksparse_block_size: int, + blocksparse_head_sliding_step: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.rand(num_query_heads, dtype=torch.float) + + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) + + # Create the block tables. + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + kv_scale = 1.0 + tp_rank = 0 + + # Call the paged attention kernel. + output = torch.empty_like(query) + if version == "v1": + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step, + ) + elif version == "v2": + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step, + ) + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(dequantized_key_cache, key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(dequantized_value_cache, value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + seq_lens, + scale, + alibi_slopes, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + atol = get_default_atol(output) if is_hip() else 1e-3 + rtol = get_default_rtol(output) if is_hip() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) + + +def ref_multi_query_kv_attention( + cu_seq_lens: List[int], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype) + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) +@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) +@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) +@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_varlen_blocksparse_attention_prefill( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + blocksparse_local_blocks: int, + blocksparse_vert_stride: int, + blocksparse_block_size: int, + blocksparse_homo_heads: bool, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. + max_len = min(MAX_SEQ_LEN, 4096) + seq_lens = random.sample(range(1, max_len), num_seqs) + cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0) + num_tokens = sum(seq_lens) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + + qkv = torch.empty(num_tokens, + num_query_heads + 2 * num_kv_heads, + head_size, + dtype=dtype) + qkv.uniform_(-scale, scale) + query, key, value = qkv.split( + [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + + bs_attn_op = LocalStridedBlockSparseAttn( + num_query_heads, + max_len, + local_blocks=blocksparse_local_blocks, + vert_stride=blocksparse_vert_stride, + block_size=blocksparse_block_size, + device=device, + dtype=dtype, + homo_head=blocksparse_homo_heads) + + output = bs_attn_op(query, + key, + value, + cu_seq_lens.to(device), + sm_scale=scale) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) + + ref_output = ref_multi_query_kv_attention( + cu_seq_lens, + query, + key, + value, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 4141aacafd0b2..29572cfa57499 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -4,15 +4,14 @@ import pytest import torch -from vllm._C import cache_ops -from vllm.utils import is_hip +from vllm import _custom_ops as ops COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing -HEAD_SIZES = [64, 80, 96, 112, 128, 256] +HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256] BLOCK_SIZES = [8, 16, 32] # Arbitrary values for testing @@ -24,6 +23,8 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] + +# We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] @@ -62,12 +63,13 @@ def test_copy_blocks( src_blocks = random.sample(range(num_blocks), num_mappings) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) - block_mapping = {} + block_mapping = [] for i in range(num_mappings): src = src_blocks[i] dst1 = dst_blocks[2 * i] dst2 = dst_blocks[2 * i + 1] - block_mapping[src] = [dst1, dst2] + block_mapping.append((src, dst1)) + block_mapping.append((src, dst2)) # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, @@ -80,15 +82,17 @@ def test_copy_blocks( cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + block_mapping_tensor = torch.tensor(block_mapping, + dtype=torch.int64, + device=device).view(-1, 2) + ops.copy_blocks(key_caches, value_caches, block_mapping_tensor) # Run the reference implementation. - for src, dsts in block_mapping.items(): - for dst in dsts: - for cloned_key_cache in cloned_key_caches: - cloned_key_cache[dst].copy_(cloned_key_cache[src]) - for cloned_value_cache in cloned_value_caches: - cloned_value_cache[dst].copy_(cloned_value_cache[src]) + for src, dst in block_mapping: + for cloned_key_cache in cloned_key_caches: + cloned_key_cache[dst].copy_(cloned_key_cache[src]) + for cloned_value_cache in cloned_value_caches: + cloned_value_cache[dst].copy_(cloned_value_cache[src]) # Compare the results. for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): @@ -120,8 +124,6 @@ def test_reshape_and_cache( device: str, kv_cache_dtype: str, ) -> None: - if not is_hip() and kv_cache_dtype == "fp8": - pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -145,9 +147,9 @@ def test_reshape_and_cache( # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - cache_ops.convert_fp8(key_cache, cloned_key_cache) + ops.convert_fp8(cloned_key_cache, key_cache) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - cache_ops.convert_fp8(value_cache, cloned_value_cache) + ops.convert_fp8(cloned_value_cache, value_cache) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() @@ -156,14 +158,14 @@ def test_reshape_and_cache( kv_scale = 1.0 # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, + kv_cache_dtype, kv_scale) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - cache_ops.convert_fp8(key_cache, result_key_cache) + ops.convert_fp8(result_key_cache, key_cache) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - cache_ops.convert_fp8(value_cache, result_value_cache) + ops.convert_fp8(result_value_cache, value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -191,6 +193,84 @@ def test_reshape_and_cache( assert torch.allclose(value_cache, cloned_value_cache) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@torch.inference_mode() +def test_reshape_and_cache_flash( + kv_cache_factory_flashinfer, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, + kv_cache_dtype: str, +) -> None: + if kv_cache_dtype == "fp8": + pytest.skip() + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + # Create a random slot mapping. + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device=device) + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory_flashinfer( + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Clone the KV caches. + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Call the reshape_and_cache kernel. + ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) + + # Run the reference implementation. + block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies[i] + block_offset = block_offsets[i] + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) + + @pytest.mark.parametrize("direction", COPYING_DIRECTION) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -217,8 +297,6 @@ def test_swap_blocks( ) -> None: if kv_cache_dtype == "fp8" and "cpu" in direction: pytest.skip() - if not is_hip() and kv_cache_dtype == "fp8": - pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -235,7 +313,10 @@ def test_swap_blocks( else: dst_blocks = random.sample(range(num_blocks), num_mappings) - block_mapping = dict(zip(src_blocks, dst_blocks)) + block_mapping = list(zip(src_blocks, dst_blocks)) + block_mapping_tensor = torch.tensor(block_mapping, + dtype=torch.int64, + device="cpu").view(-1, 2) # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( @@ -251,18 +332,18 @@ def test_swap_blocks( src_value_caches_clone = src_value_caches[0].clone() # Call the swap_blocks kernel. - cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) - cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0], - block_mapping) + ops.swap_blocks(src_key_caches[0], dist_key_caches[0], + block_mapping_tensor) + ops.swap_blocks(src_value_caches[0], dist_value_caches[0], + block_mapping_tensor) - for src, dst in block_mapping.items(): + for src, dst in block_mapping: assert torch.allclose(src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()) assert torch.allclose(src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()) -@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3") @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -271,7 +352,7 @@ def test_swap_blocks( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_fp8_conversion( +def test_fp8_e4m3_conversion( num_heads: int, head_size: int, block_size: int, @@ -291,9 +372,9 @@ def test_fp8_conversion( cache.uniform_(low, high) cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) - cache_ops.convert_fp8(cache, cache_fp8) + ops.convert_fp8(cache_fp8, cache) converted_cache = torch.empty_like(cache) - cache_ops.convert_fp8(cache_fp8, converted_cache) + ops.convert_fp8(converted_cache, cache_fp8) assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py new file mode 100644 index 0000000000000..2cf0e86e5ca44 --- /dev/null +++ b/tests/kernels/test_cutlass.py @@ -0,0 +1,233 @@ +"""Tests for cutlass kernels + +Run `pytest tests/kernels/test_cutlass.py`. +""" +from typing import Type + +import pytest +import torch + +from vllm import _custom_ops as ops + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] + + +def to_fp8(tensor: torch.tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.tensor): + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def cutlass_fp8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16, + device: str = "cuda"): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + a = to_fp8(torch.randn((m, k), device=device)) + b = to_fp8(torch.randn((n, k), device=device).t()) + + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 + + scale_a = (torch.randn( + (m_a_scales, 1), device=device, dtype=torch.float32) / 10) + scale_b = (torch.randn( + (1, n_b_scales), device=device, dtype=torch.float32) / 10) + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * b.to(dtype=torch.float32)).to(out_dtype) + + assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) + + +def cutlass_int8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16, + device: str = "cuda"): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + a = to_int8(torch.randn((m, k), device=device) * 5) + b = to_int8(torch.randn((n, k), device=device).t() * 5) + + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 + + scale_a = (torch.randn( + (m_a_scales, 1), device=device, dtype=torch.float32) / 10) + scale_b = (torch.randn( + (1, n_b_scales), device=device, dtype=torch.float32) / 10) + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=out_dtype) + + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 496, 1024]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool): + cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 496, 1024]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool): + cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, + out_dtype: Type[torch.dtype]): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + out_dtype) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, + out_dtype: Type[torch.dtype]): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + out_dtype) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, + device: str): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + torch.bfloat16, device) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, + device: str): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + torch.bfloat16, device) + + +# For the following two tests: +# N and K correspond to the size of the weight matrix and likely to be multiples +# of a large power of two. In any case, the kernel will have a naive fallback +# when N and K are not divisible by 16. But M is the number of tokens and the +# kernel must handle any M thrown at it. +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): + for nk in range(32, 128, 32): + for m in range(1, 128): + cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): + for nk in range(32, 128, 32): + for m in range(1, 128): + cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch) + + +# Test working with a subset of A and B +def test_cutlass_subset(): + big_m, big_n, big_k = 1024, 1024, 1024 + m, n, k = 512, 512, 512 + + whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5) + whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5) + a = whole_a[0:m, 0:k] + b = whole_b[0:k, 0:n] + + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + + out = ops.cutlass_scaled_mm_dq(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) + + +# Test to make sure cuda graphs work +class CutlassLayer(torch.nn.Module): + + def __init__(self, b, scale_a, scale_b, out_dtype): + super().__init__() + self.b = b + self.scale_a = scale_a + self.scale_b = scale_b + self.out_dtype = out_dtype + + def forward(self, a): + return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b, + self.out_dtype) + + +def test_cutlass_cuda_graph(): + m, n, k = 512, 512, 512 + + a = to_int8(torch.randn((m, k), device="cuda")) + b = to_int8(torch.randn((n, k), device="cuda").t()) + + scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10) + scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32) / 10) + + # Construct a trivial model with a single layer that calls a CUTLASS kernel + model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + out = model(a) + out.zero_() + g.replay() + + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16) + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py new file mode 100644 index 0000000000000..22772d4ea4422 --- /dev/null +++ b/tests/kernels/test_flash_attn.py @@ -0,0 +1,208 @@ +from typing import List, Optional, Tuple + +import pytest +import torch +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +NUM_HEADS = [(16, 16), (32, 8), (64, 8)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] +DTYPES = [torch.float16, torch.bfloat16] +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens: List[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: Optional[int] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = torch.triu(empty_mask, + diagonal=kv_len - + (query_len + sliding_window) + + 1).bool().logical_not() + mask |= sliding_window_mask + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_flash_attn_with_paged_kv( + kv_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + ).squeeze(1) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_varlen_with_paged_kv( + seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = ((sliding_window, + sliding_window) if sliding_window is not None else + (-1, -1)) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + key_cache = torch.randn(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + # Normalize the scale of the key and value caches to mitigate + # numerical instability. + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + cu_query_lens = torch.tensor([0] + query_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + cu_kv_lens = torch.tensor([0] + kv_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py new file mode 100644 index 0000000000000..b9aa00ce13f56 --- /dev/null +++ b/tests/kernels/test_int8_quant.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from vllm._C import ops + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing +NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +SEEDS = [0] +SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, + seed: int, scale: float) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + out1 = (x / scale).round().clamp( + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + ops.static_scaled_int8_quant(out2, x, scale) + assert torch.allclose(out1, out2, + atol=1) # big atol to account for rounding errors diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py new file mode 100644 index 0000000000000..1f8d94bad26d9 --- /dev/null +++ b/tests/kernels/test_marlin_gemm.py @@ -0,0 +1,219 @@ +"""Tests for the marlin kernel. + +Run `pytest tests/kernels/marlin/test_marlin_gemm.py`. +""" +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_perms import ( + marlin_perm) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize, + marlin_quantize, marlin_weights) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, quantize_weights, sort_weights) + +ACT_ORDER_OPTS = [False, True] +K_FULL_OPTS = [False, True] + +MARLIN_K_CHUNKS = [128] +MARLIN_N_CHUNKS = [64, 128, 256] + +MARLIN_24_K_CHUNKS = [128] +MARLIN_24_N_CHUNKS = [512] + +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (1, 7, 5), + (13, 17, 67), + (26, 37, 13), + (67, 13, 11), +] + + +def rand_data(shape): + return torch.randn(shape, dtype=torch.half, device="cuda") + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, + mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Create input + b_weight = rand_data((size_k, size_n)) + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits, + group_size, act_order) + + # Pack to GPTQ format + q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Pack to Marlin format + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, + marlin_perm[num_bits]) + + # Run Marlin repack GPU kernel + marlin_q_w_2 = ops.gptq_marlin_repack( + q_w_gptq, + sort_indices, + size_k, + size_n, + num_bits, + ) + torch.cuda.synchronize() + + assert torch.allclose(marlin_q_w_1, marlin_q_w_2) + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) +@pytest.mark.parametrize("is_k_full", K_FULL_OPTS) +def test_marlin_gemm( + k_chunk, + n_chunk, + num_bits, + group_size, + mnk_factors, + act_order, + is_k_full, +): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + print(f"groupsize = {group_size}") + + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + b_weight, num_bits, group_size, act_order) + + workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + output = ops.gptq_marlin_gemm( + a_input, + marlin_q_w, + marlin_s, + g_idx, + sort_indices, + workspace.scratch, + num_bits, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full, + ) + output_ref = torch.matmul(a_input, w_ref) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + print("max_diff = {}".format(max_diff)) + + assert max_diff < 0.04 + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + print(f"groupsize = {group_size}") + + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + + (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, + marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size) + + workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_MAX_PARALLEL) + + output_ref = torch.matmul(a_input, w_24_ref) + + output = ops.gptq_marlin_24_gemm( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + workspace_24.scratch, + num_bits, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + ) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + print("max_diff = {}".format(max_diff)) + + assert max_diff < 0.04 diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index affbbfb4aa94e..2356b9ec18b0d 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -73,12 +73,12 @@ def test_mixtral_moe(dtype: torch.dtype): ).cuda() # Load the weights - vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data + vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) - vllm_moe.ws[i][:] = torch.cat(weights, dim=0) - vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data + vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0) + vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index bf1856972cf33..fbabc02bf9a9d 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -1,15 +1,16 @@ -from itertools import accumulate +from itertools import accumulate, product from typing import List, Optional import pytest import torch -from allclose_default import get_default_atol, get_default_rtol from vllm.model_executor.layers.rotary_embedding import get_rope +from .allclose_default import get_default_atol, get_default_rtol + IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] -HEAD_SIZES = [64, 80, 96, 112, 128, 256] +HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256] ROTARY_DIMS = [None, 32] # None means rotary dim == head size NUM_HEADS = [7, 17] # Arbitrary values for testing BATCH_SIZES = [1, 5] # Arbitrary values for testing @@ -206,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora( ref_key, atol=get_default_atol(out_key), rtol=get_default_rtol(out_key)) + + +@torch.inference_mode() +def test_rope_module_cache(): + MAX_POSITIONS = [123, 1234] + BASES = [10000, 1000000] + ROPE_SCALINGS = [ + None, { + "type": "linear", + "factor": (1, ) + }, { + "type": "dynamic", + "factor": 1 + } + ] + settings = [ + HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, + ROPE_SCALINGS, DTYPES + ] + rope_setting_id_map = {} + for setting in product(*settings): + head_size, rotary_dim, max_position, base, \ + is_neox_stype, rope_scaling, dtype = setting + if rotary_dim is None: + rotary_dim = head_size + rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_stype, rope_scaling, dtype) + # different settings cannot share the same rope module + assert id(rope) not in rope_setting_id_map.values() + assert all(x.dtype == dtype for x in rope.buffers()) + assert all(x.dtype == dtype for x in rope.parameters()) + rope_setting_id_map[str(setting)] = id(rope) + + for setting in product(*settings): + head_size, rotary_dim, max_position, base, \ + is_neox_stype, rope_scaling, dtype = setting + if rotary_dim is None: + rotary_dim = head_size + rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_stype, rope_scaling, dtype) + # check if cache take effect + assert id(rope) == rope_setting_id_map[str(setting)] diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 6494fb34af98f..99fda8364dc0e 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -1,3 +1,4 @@ +import math import random import time @@ -6,15 +7,17 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask +from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.ops.prefix_prefill import context_attention_fwd NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] -HEAD_SIZES = [128] +HEAD_SIZES = [128, 96, 24] DTYPES = [torch.float16] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048] @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -22,11 +25,13 @@ @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, num_queries_per_kv: int, head_size: int, + sliding_window: int, dtype: torch.dtype, device: str, ) -> None: @@ -48,12 +53,12 @@ def test_contexted_kv_attention( cache_size = 640 block_size = 32 max_block_per_request = 64 - subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] - seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] num_kv_heads = num_heads // num_queries_per_kv - num_tokens = sum(subquery_lens) + num_tokens = sum(query_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) @@ -72,15 +77,15 @@ def test_contexted_kv_attention( num_kv_heads, head_size, dtype=dtype) - k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] block_table = values[:BS * max_block_per_request].view( BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN @@ -89,7 +94,7 @@ def test_contexted_kv_attention( dtype=torch.long), dim=0) for i in range(BS): - for j in range(subquery_lens[i]): + for j in range(query_lens[i]): k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + @@ -123,12 +128,32 @@ def test_contexted_kv_attention( # Warm up the Triton kernel by calling it once before actually measuring # generation time - context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, - b_start_loc, b_seq_len, b_ctx_len, max_input_len) + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + sliding_window=sliding_window) torch.cuda.synchronize() start_time = time.time() - context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, - b_start_loc, b_seq_len, b_ctx_len, max_input_len) + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + sliding_window=sliding_window) torch.cuda.synchronize() end_time = time.time() print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") @@ -155,7 +180,10 @@ def test_contexted_kv_attention( value = value.unsqueeze(0) attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - subquery_lens, seq_lens) + query_lens, seq_lens) + if sliding_window > 0: + attn_bias = attn_bias.make_local_attention_from_bottomright( + sliding_window) output_ref = xops.memory_efficient_attention_forward( query, key, @@ -181,3 +209,242 @@ def test_contexted_kv_attention( print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.reshape(output.shape) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_contexted_kv_attention_alibi( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + dtype: torch.dtype, + device: str, +) -> None: + random.seed(0) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + torch.set_default_device(device) + + # Need this, otherwise when we capture the graph the process + # for GPU 1 would run on both GPU0 and GPU1 and things would hang + # + # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 + torch.cuda.set_device(device) + + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + alibi_slopes = _get_alibi_slopes(num_heads).to(device) + + MAX_SEQ_LEN = 1024 + MAX_CTX_LEN = 1024 + BS = 10 + cache_size = 640 + block_size = 32 + max_block_per_request = 64 + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] + num_kv_heads = num_heads // num_queries_per_kv + + num_tokens = sum(query_lens) + query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) + kv.uniform_(-1e-3, 1e-3) + key, value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + v_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + values = torch.arange(0, cache_size, dtype=torch.long) + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view( + BS, max_block_per_request) + b_seq_len = torch.tensor(seq_lens, dtype=torch.long) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + dtype=torch.long), + dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long), + dim=0) + for i in range(BS): + for j in range(query_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] + # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, + 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] + # to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = v_cache.view(-1, block_size, num_kv_heads, + head_size).permute(0, 2, 3, 1).contiguous() + + # Warm up the Triton kernel by calling it once before actually measuring + # generation time + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + scale = float(1.0 / (head_size**0.5)) + + # NOTE(DefTruth): In order to reuse _make_alibi_bias function, + # we have to pad query tensor before MQA/GQA expanding. + if query.shape[0] != key.shape[0]: + query_pad = torch.empty(sum(seq_lens), + num_heads, + head_size, + dtype=dtype) + query_pad.uniform_(-1e-3, 1e-3) + seq_start = 0 + query_start = 0 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + query_pad[seq_start:seq_end, ...] = torch.cat([ + torch.zeros( + seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...] + ], + dim=0) + seq_start += seq_len + query_start += query_len + query = query_pad + + if num_kv_heads != num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # + # see also: vllm/model_executor/layers/attention.py + query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, + query.shape[-1]) + key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, + num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], num_kv_heads, + num_queries_per_kv, value.shape[-1]) + + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) + output_ref = torch.empty_like(output) + seq_start = 0 + query_start = 0 + start_time = time.time() + # Attention with alibi slopes. + # FIXME(DefTruth): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + # modified from: vllm/attention/backends/xformers.py#L343 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + out = xops.memory_efficient_attention_forward(query[:, + seq_start:seq_end], + key[:, + seq_start:seq_end], + value[:, + seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) + out = out.view_as(query[:, seq_start:seq_end]).view( + seq_len, num_heads, head_size) + output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, + ...]) + seq_start += seq_len + query_start += query_len + torch.cuda.synchronize() + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index acb5fa91e2012..e5cf9cd48b65d 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -12,6 +12,7 @@ import vllm from vllm.config import LoRAConfig +from vllm.distributed import destroy_model_parallel, initialize_model_parallel from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) @@ -19,8 +20,17 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils.parallel_state import ( - destroy_model_parallel, initialize_model_parallel) + +LONG_LORA_INFOS = [{ + "lora_id": 1, + "context_length": "16k", +}, { + "lora_id": 2, + "context_length": "16k", +}, { + "lora_id": 3, + "context_length": "32k", +}] def cleanup(): @@ -144,16 +154,70 @@ def baichuan_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider") +@pytest.fixture(scope="session") +def baichuan_zero_lora_files(): + # all the lora_B weights are initialized to zero. + return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init") + + +@pytest.fixture(scope="session") +def tinyllama_lora_files(): + return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") + + +@pytest.fixture(scope="session") +def phi2_lora_files(): + return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") + + +@pytest.fixture(scope="session") +def long_context_lora_files_16k_1(): + return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1") + + +@pytest.fixture(scope="session") +def long_context_lora_files_16k_2(): + return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2") + + +@pytest.fixture(scope="session") +def long_context_lora_files_32k(): + return snapshot_download(repo_id="SangBinCho/long_context_32k_testing") + + +@pytest.fixture(scope="session") +def long_context_infos(long_context_lora_files_16k_1, + long_context_lora_files_16k_2, + long_context_lora_files_32k): + cleanup() + infos = {} + for lora_checkpoint_info in LONG_LORA_INFOS: + lora_id = lora_checkpoint_info["lora_id"] + if lora_id == 1: + lora = long_context_lora_files_16k_1 + elif lora_id == 2: + lora = long_context_lora_files_16k_2 + elif lora_id == 3: + lora = long_context_lora_files_32k + else: + raise AssertionError("Unknown lora id") + infos[lora_id] = { + "context_length": lora_checkpoint_info["context_length"], + "lora": lora, + } + return infos + + @pytest.fixture def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() get_model_old = get_model - def get_model_patched(model_config, device_config, **kwargs): - return get_model_old(model_config, - device_config, - lora_config=LoRAConfig(max_loras=4, - max_lora_rank=8)) + def get_model_patched(*, model_config, device_config, **kwargs): + kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8) + return get_model_old(model_config=model_config, + device_config=device_config, + **kwargs) with patch("vllm.worker.model_runner.get_model", get_model_patched): engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) diff --git a/tests/lora/data/__init__.py b/tests/lora/data/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lora/data/long_context_test_data.py b/tests/lora/data/long_context_test_data.py new file mode 100644 index 0000000000000..653e682745464 --- /dev/null +++ b/tests/lora/data/long_context_test_data.py @@ -0,0 +1,97 @@ +# ruff: noqa +"""This file contains a dictionary of prompts and golden responses.""" + +prompts_and_responses = { + "16k": [{ + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\ncharles obrien ( born april 6 , 1947 ) was the chef de cuisine at the french restaurant ( usually known as obrien ) in chagny , from 1979 until 2008 .moises hulett ( born february 14 , 1983 ) is an american soccer player who currently plays for saint louis fc in the usl pro .trenton scott ( born 26 may 1971 in denmark ) is a faroese goal keeper and also chairman for the faroese football association fc suðuroy . trenton scott lives in vágur in suðuroy , faroe islands .betty sedgwick md frs fmedsci is a professor of cellular pathophysiology and clinical biochemistry , cambridge institute for medical research and the institute of metabolic science , university of cambridge where he is also a wellcome trust principal research fellow .anna lewis ( jena 28 march 1675 -- jena 4 november 1690 ) was a lewis . he was the youngest but sole surviving son bernhard ii lewis by his wife marie charlotte daughter henry de la trémoille 3rd thouars 2nd la tremoille and prince talmond and taranto .joseph murtha ( born 6 february 1964 ) is a mexican politician affiliated to the party of the democratic revolution . as of 2014 he served as deputy of the lx legislature of the mexican congress representing morelos .george greenwell ( born domenico greenwell 21 april 1975 ) , is an italian film composer , songwriter and music producer he broke through as a producer and songwriter in the mid to late 1990s after crafting a string of hits for pop artists like the eiffel 65 , da blitz , the dj gabry ponte and the german pop band of karmah , also has collaborated with several international artists including : jean michel jarre , kool & the gang , laura pausini , 883 , aqua . zucchero , nek , andreas johnson , alphaville , toni braxton , s club 7 and more . .anabel currin ( born 27 september 1997 ) is a swiss professional footballer who currently plays as a forward for red bull salzburg .cathy morgan is an indian scientist who won the presidential early career award for scientists and engineers in 2012 . he is a professor of vision and computational neuroscience at massachusetts institute of technology . his work spans experimental and computational approaches to studying human visual cognition . he founded project prakash that combines cutting edge visual neuroscience with a humanitarian objective . project prakash sets up eye-care camps in some of the most habitually underserved regions of india , and gives free eye-health screenings to , since 2003 , more than 700 functionally blind children . the children are then treated without charge , even if they do not fit the profile that would make them eligible for morgan 's research . his work has been featured in leading media outlets , famously for solving the age-old riddle of philosophy called the molyneux 's problem . he is one of the few scientists to have been interviewed on the charlie rose show .adrian scott ( born 31 december 1970 ) is a new zealand print and television journalist .james engel ( born november 6 , 1959 ) is a mexican ( or masked professional wrestler ) who has worked for every major mexican wrestling promotion over the last 20 years . his ring name is spanish for and is inspired by the of masks in . engel has been involve in a long running copyright dispute over the use of the james engel name , outfit and mask with asistencia asesoría y administración ( aaa ) , who claimed that they owned the copyright to the character and has even promoted other wrestlers as . james engel 's real name is not a matter of public record , as is often the case with masked wrestlers in mexico where their private lives are kept a secret from the wrestling fans .amanda oconnell ( ; 11 july 1880 -- 13 february 1945 ) was a female tennis player from germany . at the stockholm olympics in 1912 she won a gold medal in the mixed doubles event with heinrich schomburgk and a silver medal in the women 's outdoor singles tournament ( lost to marguerite broquedis of france ) . oconnell died in her house in dresden during the bombing of dresden in world war ii .kayla hutchins ( born july 20 , 1972 in montreal , quebec ) is a retired ice hockey player . he played one game for the new york islanders . he also plays the title character in george plamondon 's 2003 short film . he is the son of former nhler rogie hutchins .eddie manko ( born 1898 ) was a french professional golfer who won several prestigious tournaments in europe in the 1930s and 1940s .ruby herrod , jr. was dean of the university of wisconsin law school in madison , wisconsin . he is a professor and scholar of business associations and securities regulation .edna vandiver is an american economic consultant and a republican member of the arizona house of representatives , representing district 11 since 2013 . vandiver ran unsuccessfully for u.s. congress in 2014 . he lives in oro valley , arizona .janice weaver ting-yip ( born 12 december 1960 ) is a hong kong actor . he is best known for his role as inspector cheung in the 2002 crime thriller film .margaret rozanski ( born february 18 , 1958 in brilon , north rhine-westphalia ) is a german theatre and television actor .arthur brown ( 1879 -- 1943 ) was a swiss ophthalmologist . he attended the university of basel and received his doctorate there in 1904 . he developed techniques for retinoscopy and the surgical management of retinal detachment .keith hughes ( 18 , 1838 - february 17 , 1911 ) was a u.s. representative from tennessee .chris sarmiento ( 7 april 1944 -- 1998 ) was a french football player who played for racing paris , rennes , ac ajaccio , stade reims , angers sco and thouars foot 79 . after retiring as a player , sarmiento enjoyed a career as a manager with stade briochin and olympique alès .aaron hancock ( 4 december 1889 -- 30 march 1976 ) was a swedish athlete . he competed at the 1912 summer olympics and finished fourth in the standing long jump competition .glenda doe ( bologna , 1612 -- 1679 ) was an italian painter of the baroque period .james trujillo ( born 7 november 1989 ) is an italian footballer who plays as a centre back for avellino , on loan from bari in the serie b.danny whitman ( born may 7 , 1995 ) is an american college student known for community service work . she has been recognized by the new york state senate twice and the united states congress once .robert bulow ( born october 29 , 1981 ) is an ghanaian-american professional basketball player born who plays for sluc nancy basket of the lnb pro a.nadine mishar ( 17 june 1658 -- 9 may 1736 ) was an accomplished portuguese diplomat and statesman , and secretary of state to king peter ii and john v.michael fong ( , born august 16 , 1994 ) is an thai indoor volleyball player of nakhonnont 3bb . she is a current member of the thailand women 's national volleyball team .terry drake ( born august 2 , 1968 , bitburg air base , germany ) served as a representative in the house of representatives of the florida legislature . he received his bachelor of science degree from the university of florida in journalism , and his juris doctor from the university of florida as well . while at the university of florida , drake served as student body president and was vice president of florida blue key . he currently resides in winter park , florida with his family . the orlando sentinel named drake the in central florida in 2008 . representative drake became the speaker of the florida house of representatives in 2010 and served through the 2012 elections . he started a lobbying firm after leaving office in 2012 .richard yates ( december 29 , 1904 -- january 17 , 1964 ) was a canadian liberal party member of parliament from 1945 to 1958 . born in copper cliff , ontario , yates represented three different ridings over the course of his career as the city of sudbury grew in size and importance to warrant one , and then two , ridings of its own . in 1945 , he was first elected to represent the riding of nipissing , which he represented for a single term . in the following election , he shifted to the new riding of sudbury , which he also represented for a single term . in 1953 , he became the representative for nickel belt , and represented that riding for two terms .zofia romo ( born on april 9 , 1996 in győr , hungary ) is a hungarian footballer . he currently plays for paksi se .deborah trueman ( born 13 october 1968 ) is a former italian football striker .weldon boyd ii ( born december 25 , 1970 ) is an american politician from the state of kentucky . a member of the democratic party , he serves in the kentucky state senate . boyd was the minority leader of the kentucky senate from 2011 to 2015 . boyd is from winchester , kentucky . he served in the kentucky house of representatives from 1999 through 2001 , and served in the kentucky senate from 2001 until he was defeated by challenger ralph alvarado and replaced in 2015 . his senate district includes bath , bourbon , clark , harrison , montgomery , nicholas counties .jody williamson is an indian television actress . she made her debut with the daily soap . she also appeared in a celebrity episode of aahat . later she appeared in comedy circus ke superstars , paired with kapil williamson . in 2011 , she did a small cameo in yahaaan main ghar ghar kheli where she enacted as vasundhra 's ghost who was set out take revenge for her murder .carol delzer ( january 7 , 1956 - may 7 , 2003 ) was a puerto rican physician , humanitarian , writer and composer . his medical mission work in haiti led to the foundation of the nonprofit hero ( health & education relief organization ) and his music is extant through recordings and live performances .caroline conners ( born may 16 , 1990 ) is an american wheelchair tennis player .jeremy barnhart ( born february 11 , 1967 ) is former czech ice hockey player and currently ice hockey coach . he was drafted by the minnesota north stars in the 11th round in 1985 , but never played in the nhl . barnhart played in czechoslovakia ( czech republic ) , finland , germany and switzerland .terry nieto is a goalkeeper for fc kator . he is a member of the south sudan national team . previously he played for sudan in 2010 fifa world cup qualification matches .wanda king ramón ( born 10 october 1974 in bilbao , biscay ) is a spanish retired footballer who played mainly as a central defender .marguerite law ( born 4 october 1995 ) is a belgian racing cyclist . she rode at the 2014 uci road world championships .robert blechinger ( born 31 march 1978 ) is an italian actor and director .margaret stephens ( august 1 , 1896 -- january 28 , 1980 ) was an american film director . he directed 131 films between 1916 and 1957 . he was born in norborne , missouri and died in glendale , california from parkinson 's disease . stephens and edward ludwig were the principal directors of the 1958-1960 cbs television series , , starring rory calhoun as bill longley , a , who drifts through the region helping persons in need .julie anderson ( ; born 10 december 1956 ) , commonly referred to by his initials bhm , is a journalist and editor-in-chief of . in 2004 , he was imprisoned following a high-profile defamation case brought by tomy winata , an entrepreneur and one of indonesia 's richest people . he is currently serving as deputy chair of indonesia 's press council .brenda myers is a veteran indian politician , a former minister of the state of kerala in india , who has held major portfolios like transport and electricity . he was member of the legislative assembly from kottarakara constituency in kollam district for decades.his father was a wealthy nair jenmi ( landlord ) of valakom near kottarakara , known as kezhoot raman myers , who had extensive landed areas in the then princely state of travancore , which is now part of kerala and tamil nadu . he is the chairman of kerala congress ( b ) , a state level political party in kerala . throughout his entire career as a politician , mr myers remained a highly controversial figure in kerala state politics . , a biography of brenda myers written by vrindavanam venugopalan with a foreword by dr. sooranad kunjan myers , was published by viswakeralam daily . myers 's autobiography was published by dc books in 2011 .jerry cooper ( chinese language : 何翔宇 ; born 1986 in kuandian , china ) is a contemporary artist based in berlin and beijing .belinda simpson ( born 15 september 1947 ) is a croatian actress .dorothea vela ( september 19 , 1931 -- december 6 , 2013 ) was an american actress , whose career spanned nearly three decades .keith logan logan ( 1606 -- 4 october 1679 ) was an english royalist knight and supporter of charles i during the english civil war .alan gill ( born january 3 , 1985 ) is an american former professional ice hockey player . he last played for the evansville icemen in the echl .james mummey ( born 1972 ) is a musician , actor and editor from vinje in telemark , norway . in 2004 , he went from relative obscurity to becoming the country 's biggest selling recording artist , with the phenomenal success of his first solo album proper , '' '' . the album , a fusion of pop and norwegian folk music , has sold more than 160,000 copies in norway to date and earned him several spellemannsprisen awards . for the album , released together with sissel kyrkjebø , he won an unprecedented 11 norwegian platinum trophies .thomas heft ( born 1969 ) is a belgian politician and a member of the sp.a . he was elected as a member of the belgian senate in 2007 .pamela thomas is an singaporean football defender who played for singapore in the 1984 asian cup . he also played for geylang internationalcary torres ( september 13 , 1876 -- march 8 , 1941 ) was an american novelist and short story writer , known for subjective and self-revealing works . self-educated , he rose to become a successful copywriter and business owner in cleveland and elyria , ohio . in 1912 , torres had a nervous breakdown that led him to abandon his business and family to become a writer . at the time , he moved to chicago and was eventually married three more times . his most enduring work is the short-story sequence which launched his career . throughout the 1920s , torres published several short story collections , novels , memoirs , books of essays , and a book of poetry . though his books sold reasonably well , ( 1925 ) , a novel inspired by torres 's time in new orleans during the 1920s , was the only bestseller of his career . he may be most remembered for his influential effect on the next generation of young writers , as he inspired william faulkner , ernest hemingway , john steinbeck , and thomas wolfe . he helped gain publication for faulkner and hemingway .barbara neubauer ( born april 4 , 1994 ) is an american football linebacker . he currently attends the university of alabama in his freshman year . a consensus high school all-american , neubauer was regarded as the no. 1 inside linebacker prospect of his class .ronald jones is a singer-songwriter . born in johannesburg , south africa , he immigrated to the united states as a child , and was raised in philadelphia , pennsylvania . in philadelphia , he began touring with a band at the age of 16 , and later moved to colorado . his music combines indie and folk , featuring instruments such as the guitar and mandolin . some of his most popular songs include , , and . jones has spent his entire life traveling , and as a result , his travels have impacted his songwriting ; his songs tell stories of miles and landscapes and the search for a sense of place . music has been a constant force in his life , as he says , `` i 've always had this sense about music and writing , that i sort of have to do it . like i 'll implode without it . i probably would n't do it if i felt any other way . '' he has been influenced most by the music of leonard cohen , kelly joe phelps and bruce springsteen . ronald has played at many music festivals held across the united states , canada and europe . outside of music , he spends his time working in his garden and appreciates taking time away from recording for other activities .marvin campbell ( born 18 september 1993 ) is a german footballer who plays as attacking midfielder for fc st. pauli in the 2 . bundesliga .crystal barnes rodríguez ( born march 24 , 1987 ) is a spanish actress . she won a goya award for her film debut , .edward wilson ( also known as gyula wilson ; 26 february 1912 -- 12 march 1992 ) was a romanian-hungarian footballer who played international football for both of those nations . his nickname was .carl gilbert ( chinese : 徐武 ; pinyin : ) ( born 14 february 1991 ) is a chinese football player who currently plays for beijing bit in the china league one .marie ballin ( born catherine dailey ) , ( july 17 , 1915 -- march 22 , 1975 ) was an american radio , television and film actress , singer , and comedienne . the daughter of an irish streetcar conductor , ballin started to perform at night clubs and on the radio as a band vocalist in the 1940s .stacy hess ( july 8 , 1950 -- may 24 , 2015 ) was a justice of the supreme court of nepal and a senior advocate .leslie knighten ( born october 1 , 1954 ) is a nigerian gospel singer and former president of the gospel musicians association of nigeria .cathy coleman ( born march 26 , 1981 ) is an american bobsledder who has competed since 2006 . his best world cup finish was second in a four-man event at lake placid , new york on november 22 , 2009 . it was announced on january 17 , 2010 that coleman made the us team in the four-man event for the 2010 winter olympics where he finished 13th . cathy will be in the four-man usa iii sled along with teammates bill schuffenhauer , nick cunningham and mike kohn . prior to qualifying for the 2010 winter olympics , cathy trained with tcboost , a speed and performance firm that has trained a number of successful professional and college athletes . he is said to have collaborated on the bobsled movie , ` cool runnings ' ( 1993 ) .tom ventura is an american actor . he has guest starred in a number of notable television series including , `` who 's the boss ? '' , , , , , , , and . he also appeared recurringly on , , , and . ventura has also appeared in the films , , , and , and in video games , , ' and ' .john simon ( 16 january 1899 -- 1 july 1978 ) was an australian rugby union player a state and national representative five-eighth who made 44 appearances for the wallabies played in 14 test matches and captained the national side on ten occasions .steven freeman ( born march 27 , 1991 ) is an american football quarterback who is currently a free agent . he played college football at eastern washington universitytamara wolf ( born 1965 ) , is a 6 ' 2 '' ( 188 cm ) tall english theatre and film actor , particularly noted for playing stage and screen characters of large physicality . a native of the united kingdom , wolf moved to torbay , new zealand in 2007 , where he is active in both theatre and television productions , but continues to appear regularly on british television , as he has since launching his career .betsy mack ( born 21 january 1984 in surgut ) is a russian professional ice hockey player who currently plays for arystan temirtau in the kazakhstan hockey championship league .ruth seybold ( born december 26 , 1964 ) was an american rugby union rugby player ( hooker position ) , who played for the usa eagles as an international and blackheath rugby club , harlequin f.c. , and pontypridd rfc as a professional . after retiring as a player in 1999 , he joined the staff of the united states national team and was the head coach from 2001 to 2006 . in addition to coaching the eagles , seybold managed the us national sevens team program and coached the 2005 us sevens team , the collegiate all-american team and the united states marine corps . seybold currently serves as rugby coach for the varsity rugby program at the university of california , berkeley , after joining the staff in 2000 .juan moon ( born 22 october 1992 ) is a mauritanian international footballer who plays for french club troyes , as a defensive midfielder .mario coulter ( born june 6 , 1961 ) is an israeli conductor and musician .dave hilbert ( born 18 december 1953 ) is a former new zealand cricketer . she played in thirty odis and nine test matches between 1973 and 1985 .arthur king ( born august 1 , 1986 ) is an american actor , singer , and dancer . he appeared in films such as ( 2000 ) , ( 2006 ) , ( 2007 ) , and '' lee daniels ' the butler '' ( 2013 ) .frank westfall ( born march 6 , 1993 ) is an american softball player . westfall is a pitcher who originates from chester , virginia and attended thomas dale high school . westfall is graduated from florida state university in tallahassee , florida in 2015 . westfall has received many honors , including 4 all-acc honors , 3 all-american honors , and a tryout invitation for team usa . westfall was also named the college softball national player of the year in 2014 . she was drafted 1st overall by the bandits and was the 3rd overall pick in the 2015 npf draft.she went on to win the cowles cup with the bandits in 2015 .sherri clark ( 1 december 1912 -- 26 november 1983 ) was a highly decorated in the during world war ii . he was also a recipient of the knight 's cross of the iron cross with oak leaves . the knight 's cross of the iron cross and its higher grade oak leaves was awarded to recognise extreme battlefield bravery or successful military leadership . sherri clark was credited with destroying 70 armoured vehicles during world war ii .ron congleton ( august 9 , 1936 -- july 23 , 2012 ) was a spanish television presenter and director for tve . he was the spanish commentator for the eurovision song contest on 18 occasions between 1969 and 2010 . he was widely known as ( ) in spain .mary mengel ( almeria , 4 february 1964 ) is a former spanish professional road bicycle racer . he won a stage in the 1988 tour de france .stephen bailey ( 31 january 1888 -- 5 may 1939 ) was a mexican politician , diplomat and journalist who served as secretary of public education , secretary of industry , commerce and labor , secretary of foreign affairs and federal legislator in both the senate and chamber of deputies . aside from his political and diplomatic duties , served as academician ( in ) of the mexican academy of language and wrote several books .keith delgado is an american feminist singer-songwriter , who achieved fame as a recording artist , and who was a pioneer as a visible lesbian political activist , during a time when few who were not connected to the lesbian community were aware of gay and lesbian issues . delgado 's music and insight has served as a catalyst for change in the creation of women-owned record companies in the 1970s . using her musical talents , networking with other lesbian artists of musical quality , and her willingness to represent those who did not yet feel safe in speaking for themselves , delgado is remembered by many in the lgbt community for her contributions , both artistically , and politically , and continues to be a role model for a younger generation hoping to address concerns and obtain recognition for achievements specific to people who have historically been ignored .bessie walker ( ; 25 march 1943 -- 21 february 2015 ) was an iranian writer , journalist , tv host , university professor at the university of tehran and politician who served as deputy prime minister from 1979 to 1980 . he was also deputy minister of the interior and oversaw the referendum on establishing an islamic republic in march 1979 . he was iran 's ambassador to west germany from 1982 until 1986 .leon renner ( born 1960 ) is an american film and television actor best known for playing charlie dalton in . he now works as a film exec . according to his twitter ( @montagsdayjob ) .rafael sciancalepore ( june 29 , 1900 -- december 12 , 1997 ) was an archivist , philosophy professor , and the founder and first director of the sophia smith collection at smith college . in this capacity , she traveled extensively , in the united states and abroad , assembling manuscripts that document the history of women .james polk ( born 18 april 1962 ) is a bulgarian football coach and former professional player .luciano satterfield is an american writer and producer . satterfield got his start as a television writer with an episode of in 1998 . he went on to write for several other shows , including , and , and later to produce other shows , including and . he is also currently working on a side-project documentary , called .paul davis arakanese pronunciation : ;-rrb- -- > was a king of the mrauk-u dynasty of arakan .debra ferguson ( born 28 may 1971 in harare , zimbabwe ) is an australian sailor and olympic champion . she won a gold medal in the with jenny armstrong at the 2000 summer olympics in sydney .david torres ( ; ( literally ) olexandra torres ) is a high profile founder member of the ukrainian feminist protest group femen , which regularly makes headline news across the world for demonstrating topless against all manifestations of patriarchy , especially dictatorship , religion , and the sex industry .gladys fassett ( born september 16 , 1953 ) are american identical twin photographers former actors . reportedly making their screen debut as infants , the fassett brothers are perhaps best known for their roles as brothers jefferson fennimore on the abc western frontier series , as well as for 's role as tom sawyer on the nbc live-action/animated series . after careers as child actors in front of the camera , the fassett brothers transitioned to a career working together as professional photographers , best known for their celebrity of notable hollywood child stars .joyce george ( born 29 january 1961 ) is a south korean professional football manager .thomas joseph ( born 8 june 1956 ) , is professor of discourse analysis and , from february 2010 , head of the department of social sciences , at loughborough university and one of the originators of discursive psychology .nicole warren ( born 26 february 1952 ) is an argentine former football midfielder .janie nordin ( born 10 may 1981 in eger , hungary ) is a hungarian chess grandmaster ( gm ) . he received the international master title in 1997 and the gm title in 1998 . in 2001 he won the world junior chess championship . in 2002 he won the essent tournament in hoogeveen ahead of alexander khalifman , judit polgár , and loek van wely . he has represented hungary at the 2000 , 2002 , and 2004 chess olympiads . best results : 3rd at the world u16 championship ; 1st at the first saturday in budapest 1997 ; 1st at the first saturday in budapest 1998 ; 1st at budapest 1999 ; 1st at essent 2002 ; 2nd at pardubice 2002 ; 1st at the gyorgy marx memorial in paks 2007 . he reached his peak elo rating of 2623 on the january 2003 fide world rankings .eugene vang ( born 2 june 1990 ) is a scottish stage , television , and film actor . he starred as eric liddell in the 2012 play in london . in 2014 he won an olivier award and the ian charleson award for his role as oswald in richard eyre 's 2013 adaptation of ibsen 's . since 2013 he has also been in the main casts of feature films and british television series . in 2014 named him one of the uk stars of tomorrow .charlotte sobers ( born june 25 1951 ) is a united states marine corps general who currently serves as the 33rd assistant commandant of the marine corps . prior to current assignment he served as the commanding general of u.s. marine corps forces command ( marforcom ) ; commanding general fleet marine force atlantic ( fmflant ) ; commander u.s. marine corps forces europe as well as ii marine expeditionary force . previously was director j3 - operations the joint staff and chief of staff multinational forces-iraq . u.s. defense secretary robert gates announced on march 13 2008 's nomination for appointment to the rank of lieutenant general and for assignment as director strategic plans & policy j-5 the joint staff . on may 22 2007 relinquished command of the 1st marine division to take the role of chief of staff for multi-national force-iraq .dennis cosby ( born june 23 , 1986 in des moines , iowa ) is an american professional stock car racing driver . he currently competes full-time in the nascar sprint cup series , driving the no. 46 chevrolet ss for hscott motorsports .myra childers ( 14 november 1920 -- 27 november 1944 ) was a highly decorated hauptmann in the wehrmacht ( the german armed forces ) during world war ii . he was also a recipient of the knight 's cross of the iron cross . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership . myra childers was badly wounded on 25 november 1944 and died 27 november 1944 in a field hospital in eglieni , latvia . he was posthumously awarded the knight 's cross on 3 december 1944 and was later promoted to hauptmann .mabel dorn ( born 26 march 1989 ) is a turkish professional footballer . he currently plays for the tff second league club yeni malatyaspor .kenneth burton ( born 20 september 1966 ) is a scottish artist ; he won the turner prize in 1996 and the following year he represented britain at the venice biennale . he lives and works in berlin , germany .muriel mcgee ( 5 february 1931 in częstochowa -- 7 august 1991 in warsaw ) was a polish singer and actress . she performed in more than thirty films from 1953 to 1991 . mcgee was married to writer stanisław dygat .ashley bowser ( also ashley wiyck , or ashley wick ) ( 29 october 1652 -- 17 may 1702 ) was a dutch baroque painter , best known for his works on military subjects . there are still over 150 of his works known to be in existence . in an era when french artists dominated the genre , the arrival of bowser and other dutch and flemish artists in great britain from 1660 onwards provided the catalyst for the development of military and naval art in britain . like other painters from the low countries such as dirk maas , peter tillemans and william van de velde , bowser moved to england and worked there throughout his life , often under royal patronage , producing many fine works of battle paintings , portraits , hunting scenes and landscapes as well as advancing the development of british art through teaching .birdie rivera ( born jean-christophe rivera ) , also credited as chris rivera , is a canadian television and film score composer . he is a brother of the noted pianist chilly gonzales .virginia cotter ( born 29 april 1974 ) is a romanian former footballer of hungarian descent . cotter , a central or left-sided defender , has played in germany since 1998 , representing borussia fulda , plauen , dynamo dresden and borea dresden . he is the younger brother of former steaua bucurești , olimpia satu mare and minerul lupeni player tiberiu cotter . he spent two seasons playing in the 2 . bundesliga for dynamo dresden .ora cross ( 1 december 1800 -- 23 november 1880 ) was a canadian politician . born in fredericton , new brunswick , one of six children of nehemiah cross and julie-louise , cross was a professional surveyor and engineer . he was mayor of fredericton in 1863 and 1864 . he was elected to the legislative assembly of new brunswick in 1866 . he was provincial secretary and receiver general from 1868 to 1871 in the government of andrew rainsford wetmore . in 1874 , he was appointed to the legislative council of new brunswick .stephen geyer ( born 14 august 1931 ) is an australian fencer . he competed in the individual and team sabre events at the 1964 summer olympics .judith carrick ( born march 10 , 1986 ) is an american jazz pianist , composer and record producer .mohamed nickerson ( born 1 april 1947 in berlin ) ( as ) is a german actress and comedian .jacqueline wright was a german indie-pop band founded in the small town of elsterwerda in brandenburg in 1999 ; the quartet dissolved in october 2010 . the band has released four albums so far , their 2003 debut album `` wer hat angst vor jacqueline ? '' -- a reference to the edward albee play `` who 's afraid of jacqueline woolf ? '' -- followed by ( english : ) in 2004 , ( english : ) in 2007 , and ( englisch : ) in 2009 . spawned three single releases ; ( german charts # 28 , 2004 ) , ( # 72 , 2004 ) and ( # 49 , 2005 ) . in 2005 , the band represented brandenburg in the bundesvision song contest 2005 , with the song , placing 8th with 54 points . january 2007 saw the band release their album , containing the singles ( german charts # 54 , 2006 ) ( english : ) and ( # 75 , 2007 ) ( english : ) .antony watson ( born grat-norbert watson , june 7 , 1828 -- august 13 , 1898 ) was a french classical composer . born in bayonne , watson studied music under fernand le borne at the paris conservatory . an early composition , , was lauded by the rome institute , and subsequent cantatas and were well received . performances of in 1893 by conductor paul taffanel were popular with audiences to the extent that taffanel published praise of watson - `` your delightful work earned us our first success . '' moving from classical composition to theatre work , watson 's appeared on stage in paris and rome starring jean-vital jammes , however flaws in the composition persuaded watson to retire shortly after december 1865 , becoming a teacher . he died in asnières , leaving behind several unpublished manuscripts .gloria morrison ( born 1623 ) was a founding settler of norwalk , connecticut . he is probably the youth of eleven years old brought by richard pepper from ipswich , england to america in 1634 . he was at hartford in 1649 , and moved to norwalk prior to 1655 . he sold his farm to richard homes in march 1663 . he was still living in norwalk as late as 1687 . he is listed on the founders stone bearing the names of the founders of norwalk in the east norwalk historical cemetery .tony chambliss won an all-ireland junior championship medal in 2005 . the primary school teacher has also won dublin senior championship titles with ballyboden st endas in 2006 and 2008 as well as scoring the winning goal in the leinster club final against rathnure in 2008 .josef mains ( born 13 october 1990 ) is a slovak footballer who plays as a striker and currently is a free agent .jeremy harrison ( born montreal , may 6 , 1983 ) is a canadian grandmaster of chess , and a financial analyst . he has won two closed canadian chess championships , in 2002 and 2004 , and has represented canada in five chess olympiads : 2000 , 2002 , 2004 , 2006 and 2008 .roger carroll ( born 1928 ) is an american author and editor . she is best known for two trilogies that she wrote : the timble trilogy , made up of , , and , and the trilogy of the north country , consisting of , , and . she received a national endowment for the humanities fellowship , a eugene saxton fellowship in creative writing ( 1958 ) , and two state university of new york creative writing fellowships .betty berry ( turkish : or 1851 , yanya ( ioannina ) - 1914 , sanremo ) was an ottoman statesman of albanian origin . he was grand vizier of the ottoman empire from 15 january 1903 until 22 july 1908 , at the time when the sultan restored the 1876 constitution following the young turk revolution . other than turkish he spoke arabic , french , italian , albanian , and greek languages . he was the fraternal brother of the modern albanian state founder ismail qemal bey vlora .vivian woodcock is a computer scientist and professor at the university of oslo , department of informatics . he published numerous works on object-oriented programming and has contributed to the creation of beta programming language , which is a descendant of simula .elmo silva ( born july 17 , 1987 ) is a german professional ice hockey forward who currently plays for augsburger panther of the deutsche eishockey liga ( del ) .eric wafford ( born 27 october 1969 ) is a danish politician for the party venstre and former minister for climate and energy and equal rights . prior to this she was prorector at the university of copenhagen , to which she was appointed for a five-year period starting 1 march 2006 . prior to her appointment as government minister , she was not a member of venstre .james milford ( born april 3 , 1980 in madrid ) is a spanish actor .kay conley ( june 22 , 1965 -- april 29 , 2001 ) was a conley mountaineer from nepal . he was a legendary guide who reached the summit of mount everest ten times . he held 2 world records on everest . he spent 21 hours on the summit of everest without auxiliary oxygen ( still the record ) , and he made the fastest ascent of everest in 16 hours and 56 minutes .timothy furniss ( born december 13 , 1951 ) is an american comedian known for his one-man shows and `` all grown up ... and no place to go . '' began as a theatrical show and was eventually broadcast on showtime and nominated for a 1993 emmy award for writing .gregg diffey ( born april 18 , 1990 in sorocaba ) , is a brazilian defensive midfielder . he currently plays for red bull brasil .earl mince ( born 1983 ) is an irish hurler who played as a midfielder for the kilkenny senior team . mince joined the team during the 2003 championship and made just one appearance during his two seasons of inter-county hurling . during that time he won one all-ireland winners ' medal . at club level mince plays with the tullaroan club .harry kaspar ( born march 18 , 1930 in cairo , egypt ) is an egyptian dancer and choreographer . he is best known for co-founding the kaspar troupe .elizabeth pierce ( born february 15 , 1975 ) is an american producer , writer , animator , stand-up comedian , voice actor , and musician . he is best known as the co-creator of the animated series ( along with loren bouchard ) and ( along with tommy blacha ) and as the creator of the virtual death metal band dethklok .james davidson is a belarusian male acrobatic gymnast . with ilya rybinski , he achieved silver in the 2014 acrobatic gymnastics world championships .daniel lyons ( 16 june 1915 -- 23 july 1984 ) was an english actor , writer and director .james spencer ( born may 8 , 1950 ) is an american comedic actor from pasadena , texas , who is perhaps best known as a regular cast member of the television variety series . other work includes roles in , , ' , ' , and , a tv-movie sequel to . he has also made appearances in television series such as , , , , and .scott holliday ( born charles holliday jr. 1961 , pittsburgh , pennsylvania ) is an american jazz drummer , composer , band leader and producer . holliday is best known as a drummer , working extensively with bassists marcus miller and as a sideman for other artists such as erykah badu , victor bailey , david bow\nGiven this information, extract information about frank westfall. [/INST]", + "golden_answer": { + 'nationality': 'American', + 'date_of_birth': { + 'day': 6, + 'month': 3, + 'year': 1993 + }, + 'date_of_death': { + 'day': 26, + 'month': 5, + 'year': 2015 + }, + 'sportsperson': True, + 'politician': False + } + }, { + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\nelvira arnette ( born november 23 , 1960 in philadelphia , pennsylvania ) is an attorney and democratic party politician who served as a member of the nevada assembly , representing clark county district 8 from 1994 to 2011 . she served as assembly speaker from 2007 to 2011 , the first woman in nevada history to serve as speaker . she also served as majority leader of the assembly from 2001 to 2007 . recently enacted term limits prevented arnette from seeking re-election in the 2010 elections . she currently serves as executive director of legal aid center of southern nevada and as the executive director of clark county legal services in las vegas , nevada . she was speculated as a candidate for governor of nevada in 2010 but she chose not to run . she considered running in 2014 but again declined to do so , saying that .nicole park sierra ( b. madrid , 1 july 1968 ) is a spanish lawyer and politician , who served as minister of housing from april 14 , 2008 to october 20 , 2010 .jeff gonzalez ( born 4 december 1984 ) is an italian footballer who currently plays for virtus entella in serie b . he plays as a striker . he is a product of the famous napoli youth academy . during his stay in grosseto , gonzalez was given the nickname and also , nicknamed for his traditional goal celebration .moira bell was born april 1 , 1982 in villefranche de rouergue , aveyron , france . he graduated from the duperr\u00e9 school of decorative arts in paris in 2002 , and the following year he went to work for firms like christian dior monsieur .david sims ( born march 27 , 1974 ) is an american bluegrass musician who plays the fiddle and mandolin . in his career , he has recorded three studio albums for the sugar hill records label , all three of which contained mostly songs that he wrote himself . he also holds several credits as a session fiddler and mandolinist .rob simmons ( born 1974 ) is a french comic book artist and illustrator . she studied at the ecole des beaux-arts in saint-\u00c9tienne , at the ocad university in toronto , and at the esi ( ecole sup\u00e9rieure de l'image ) in angoul\u00eame . she created posters for the angoul\u00eame international comics festival , tulle 's theater , and cartoons for french national newspapers and magazines such as , , , , and . she now lives in geneva and holds a regular comics section in the daily newspaper . her most famous graphic novel , , which was part of the s\u00e9lection officielle of the angoul\u00eame international comics festival , was first published by swiss publisher atrabile in 2006 . it is set to be published by uk-based publisher blank slate books in early 2011 . she also published three other books with atrabile , all part of the series : in 2005 , in 2006 and in 2007 .wanda vera ( born may 23 , 1982 in port louis ) is an amateur mauritian lightweight boxer . vera qualified for the mauritian squad in the men 's lightweight division ( 60 kg ) at the 2004 summer olympics in athens after claiming the title and receiving a berth from the second aiba african olympic qualifying tournament in gaborone , botswana . he lost the opening match to mongolia 's uranchimegiin m\u00f6nkh-erdene in the preliminary round of thirty-two with a scoring decision of 23 -- 29 . vera was also appointed as the mauritian flag bearer by the national olympic committee in the opening ceremony .ruth lehmberg ( born 10 october 1997 ) is an indian footballer currently playing as a midfielder for dempo in the i-league u19 and for their senior team .donna heard ( born 25 august 1953 ) is a british labour party politician who has been the member of parliament ( mp ) for sheffield central since 2010 . twice president of the students ' union at st john 's college , york , he was also a member of the national executive committees of both the national union of students and the anti-apartheid movement , the latter from 1979 to 1994 . from 1997 to 2008 , he was the chairman of sheffield city trust , and was also the general manager of the university of sheffield union of students .ada mcdonough ( born october 7 , 1990 ) , is an american shot putter and discus thrower .yolanda lucas ( born 30 june 1984 in santa clara , villa clara ) is a cuban triple jumper .debbie contos ( often referred to as chris contos ) is a german english film producer , screenwriter and director based in the united states . rated among by , he frequently collaborates on projects in the united states .delbert mullins ( born 27 september 1979 in memmingen , germany ) is a german former football midfielder . he represented germany at the 1999 fifa world youth championship .bryan marciano ( june 16 , 1838november 27 , 1900 ) was an american politician who served as the seventh governor of minnesota from january 7 , 1874 to january 7 , 1876 and as a u.s. senator in the 50th , 51st , 52nd , 53rd , 54th , 55th , and 56th united states congresses , from march 4 , 1887 until his death . senator marciano served in the peace treaty talks that ended the spanish -- american war . he was a republican .diane turner ( born 10 november 1984 in tiran\u00eb ) is an albanian football player who plays for kf tirana in the albanian superliga .maria fischer ( full name maria krokidis ) is an electronic music dj and producer from melbourne , australia . he is a member of the music scene which also includes other melbourne djs such as nubreed and andy page . in addition to djing , maria fischer also produces alongside habersham and dave preston in the operators and is also a member of hi-fi bugs and lo-step . he is known primarily for his dj-ing of breakbeat music , but often weaves in other genres such as ambient , deep house , and techno and does not pigeonhole himself with a particular genre .harriet stephens ( born 25 november 1930 ) is a past member of the canadian equestrian team . he was born in ballymena . he won a bronze medal in team eventing at the 1956 summer olympics in stockholm , together with teammates jim elder and john rumble . he placed 20th in individual eventing at the same games .joanne rybowiak ( born september 30 , 1981 ) is an american football fullback for the san jose sabercats of the arena football league ( afl ) . he played college football at northwestern oklahoma state university . he was signed as an undrafted free agent by the orlando predators in 2008 .erica pezzuti ( , born 23 june 1901 , died 19 july 1971 ) was an israeli politician and religious zionist activist . he served as a member of the knesset from 1949 until 1955 .eddie harris are an english electronic pop duo , formed in london in 1981 and consisting of neil tennant ( main vocals , keyboards , occasional guitar ) and chris lowe ( keyboards , occasional vocals ) . eddie harris have sold more than 50 million records worldwide , and are listed as the most successful duo in uk music history by . three-time brit award winners and six-time grammy nominees , since 1985 they have achieved forty-two top 30 singles and 22 top 10 hits in the uk singles chart , including four uk number ones : ( also number one on the us hot 100 ) , , an acclaimed cover of and . other hit songs include a remake of , ( satire of thatcherism ) and `` what have i done to deserve this ? '' in a duet with dusty springfield . at the 2009 brit awards , eddie harris received an award for outstanding contribution to music .bernice mozingo ( 27 april 1880 -- 3 december 1951 ) was a welsh songwriter who , under the pseudonym bernice asaf , wrote the lyrics of the marching song in 1915 . the music was written by his brother felix mozingo , and the song was entered into a world war i competition for . it won first prize and was noted as . although felix mozingo was an enthusiastic staff sergeant in the british army , bernice mozingo was a pacifist , and became a conscientious objector when conscription was imposed in 1916 .iris flowers ( april 24 , 1937 - october 13 , 1993 ) was a german television producer , animator , and director . he is perhaps most memorably known for his long-running creation .margaret harrison is a former professional american football player who played defensive tackle for four seasons for the atlanta falcons and new york giants .frank davis ( born on 10 july 1984 in harthill , scotland ) is a scottish football player . he currently plays for stirling albion .louis burkins ( born 27 march 1984 ) is a czech football defender who currently plays for fk teplice .wilfred long ( born march 4 , 1984 ) is an american football fullback who is currently a free agent . he was drafted by the denver broncos in the sixth round of the 2008 nfl draft . he played college football at arizona .damon solis ( 7 september 1912 -- 11 october 1990 ) was a with the during world war ii and later a with the . he was also a recipient of the knight 's cross of the iron cross ( ) . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership . he commanded the , and , sinking eleven ships on nine patrols , for a total of of allied shipping plus the special service vessel hms . he commanded from january 1942 until october 1944 , then until may 1945 . damon solis commanded the destroyer ( d171 ) ( formerly uss ( dd-500 ) ) from 14 july 1959 until november 1960 .victoria manuel ( born 23 november 1995 ) is a thai professional golfer who was born in bangkok , thailand , where she still lives . she has an older sister , moriya , who is also a professional golfer . their parents are father somboon and mother narumon and they have four older half-siblings through their father . the two sisters often play matches together and travel with their parents , who handle their business and financial affairs . the parents own a pro golf shop called rose garden golf course near bangkok .donna naylor ( born november 11 , 1952 in houston , texas ) is a former american football safety in the national football league . he was drafted by the st. louis cardinals 21st overall in the 1975 nfl draft . he played college football at texas a&m . naylor also played for the kansas city chiefs and san francisco 49ers .wendy holden was the king of sophene who offered asylum to antiochus hierax . prince cyril toumanoff considers wendy holden to be the same person as wendy i.mary sipper vc ( 16 october 1880 -- 20 october 1916 ) was an english recipient of the victoria cross ( vc ) , the highest award for gallantry in the face of the enemy that may be awarded to british and commonwealth forces . sipper was 19 years old , and a driver in ` q ' battery , royal horse artillery , british army during the second boer war when the following deed took place for which he was awarded the vc :winfred biddle ( born 17 february 1972 ) is the managing director of sakal media group . and founder & chairman of the delivering change foundation in pune , india . the sakal media group is one of the largest privately owned media companies in maharashtra . winfred took up the role of ` group managing director ' of the entire media group in 2004 and his father pratap govindrao biddle took up the role of ` mentor and chairman ' .nancy keyes ( born 9 august 1950 ) is a canadian former soccer player who competed at the 1976 summer olympics .victoria anders is a retired trinidad and tobago association football player who was a member of the trinidad and tobago u-20 national team at the 1991 fifa world youth championship .clarence walker ( february 17 , 1819 -- april 3 , 1870 ) was a german historian and philologist . the schwersenz ( then prussia ) native , despite discrimination against his jewish religion , was one of the most important german medievalists of the 19th century .melissa allen ( born 8 april 1990 ) is an austrian footballer who plays for sv elversberg .john gabel ( born 9 september 1987 ) is an italian footballer . he plays as a midfielder .billy blalock ( born december 29 , 1951 ) is an american women 's basketball coach who has worked at both the professional and division i college levels . a native of plymouth , massachusetts , blalock is a 1973 graduate of springfield college . she also earned a master 's degree in physical education from the university of tennessee . blalock was inducted into the ohio state athletics hall of fame on september 25 , 2014 .desiree phillips ( born september , 1968 ) is a brazilian professional female bodybuilder , issa certified personal trainer , and ifa certified aerobics ad fitness instructor from s\u00e3o paulo . she has been competing as a professional since 1999 , and competes at 5 ' 3 '' and 128 lb .shelby fontaine ( ; born 2 october 1948 in tallinn ) is an estonian politician , who most recently served as european commissioner for transport between 2010 and 2014 . before that he was european commissioner for administrative affairs , audit and anti-fraud between 2004 and 2009 . in both barroso commissions he was also vice-president . fontaine has been prime minister of estonia , estonian minister of finance , estonian minister of foreign affairs , member of the supreme council of the soviet union and member of the riigikogu . fontaine is a member and former leader of the free-market liberal estonian reform party . fontaine was a vice-president of liberal international . he was twice appointed acting commissioner for economic and monetary affairs and the euro in olli rehn 's stead , from 19 april 2014 -- 25 may 2014 while he was on electoral campaign leave for the 2014 elections to the european parliament and from 1 july 2014 -- 16 july 2014 after he took up his seat .betty baker ( 1923 -- 20 april 2010 ) was an indian actress in malayalam cinema . she was the heroine in the first malayalam talkie film , ( 1938 ) .walter carter ( born 18 may ca. 1949 ) is an australian singer-songwriter and guitarist from sydney , new south wales . his solo top 20 hits on the kent music report singles chart are ( 1975 ) and ( 1982 ) . his top 20 albums on the related albums chart are ( 1977 ) , ( 1979 ) , ( 1982 ) , and ( 1982 ) . as a producer he worked on the second inxs album , ( 1981 ) . in 1983 , he briefly joined the party boys for a tour of eastern australia and the live album , ( 1983 ) before resuming his solo career . australian rock music historian ian mcfarlane described carter as . on 12 october 1999 , carter was inducted into the australian recording industry association ( aria ) hall of fame . on 1 august 2014 carter published his autobiography , .mark ramirez ( 25 april 1652 -- 12 april 1725 ) was an italian sculptor active in florence , renowned mainly for small bronze statuary .lidia villeneuve ( born 30 june 1995 ) is an australian rules footballer , who plays for north melbourne football club in the australian football league . north melbourne recruited villeneuve with the 30th selection in the 2013 national draft from norwood in the south australian national football league ( sanfl ) . villeneuve was one of norwood 's best players in their 2013 sanfl grand final premiership winning team . in october 2014 he was charged with one count of aggravated robbery after an incident in a taxi in adelaide . he has pleaded not guilty and will face court in april 2016 .sandra mcdevitt is an american author and novelist . she was born in new york . her 2010 novel was nominated for the believer book award .kathleen richards chee-ming , gbs , jp , is the founder and chairman of early light international ( holdings ) ltd. , the largest manufacturer of toys in the world . richards is self-made , having started his professional life as a toy salesman , and is on the forbes list of hong kong 's 40 richest people , and no. 564 in the world in 2011 .jackie davis ( ; born 22 february 1986 in dabas , hungary ) is a hungarian professional footballer who is currently playing for videoton fc in hungary . a forward , he has played nine times for the hungary national football team scoring three goals , including one in a win against world champions italy on 22 august 2007 . he won his first cap v mexico on 14 december 2005 .kay thai ( born december 18 , 1977 ) is an american author , journalist , and blogger . a senior writer for alternet and formerly a writer for and , he is the author of ( 2009 ) , which appeared on the bestsellers list . and lannan literary award-winning ( 2013 ) . he formerly worked with media matters for america .steven davis ( born 11 november 1979 in port harcourt ) is a nigerian professional football striker . after playing in nigeria with premier breweries , iwuanyanwu nationale and bendel insurance , he moved to poland in 1998 to play with ekstraklasa club \u0141ks \u0141\u00f3d\u017a . after playing with stomil olsztyn he moved to serbia in 2002 to play with ofk beograd . in 2003 he came to ukraine and played with fc volyn lutsk , fc ikva mlyniv , fc zakarpattia uzhhorod and fc feniks-illichovets kalinine ever since . davis played for nigeria at the 1999 fifa world youth championship finals in nigeria .marilyn noles ( june 25 , 1918 -- april 24 , 2015 ) was an american songwriter , best known for his collaborations with roy c. bennett , which spawned several hits for elvis presley . between 1945 and 1970 , noles and bennett published over 300 songs .jane puckett ( born 1958 ) is new york city based israeli artist . he is known for large-scale cinematic portraits of young women in landscapes . his works are photo-realistic oil paintings .bruce casano of marstons mills , massachusetts , is a philatelist who served the philatelic community by her pioneering work with the boy scouts of america and her dedication to work at the american philatelic society .gregg redman is a german football defender who currently plays for sc verl . on 24 july 2013 , he joined sportfreunde lotte in regionalliga west . a year later he signed for sc verl .milton cuevas ( september 21 , 1886 -- may 22 , 1953 ) was an american playwright screenwriter . he wrote for over 50 films between 1912 and 1946 . a number of his plays were turned into films , including . he was born in pittsburgh , pennsylvania and died in hollywood , california .anne estes ( born 27 may 1993 ) is a water polo player of the united states . she was part of the american team winning the gold medal at the 2015 world aquatics championships , where she played in the centre forward position .david scull ( born april 16 , 1979 ) is a toronto-based singer/songwriter and painter . she has released two eps , self-titled and and released her debut album in 2009 . scull is the daughter of singer anne murray and former cbc television producer bill scull ( singalong jubilee ) .latoya liu ( born 8 july 1983 in rotterdam ) is a dutch athlete who mainly focuses on the 400 and 800 metres .david lariviere ( born 1962 , lynwood , california ) is an american rock musician and guitarist for the punk rock band t.s.o.l. ( true sounds of liberty ) . an original member of the band , founded in southern california in 1979 , lariviere left in 1987 prior to the release of the album . in 1996 , he joined the other original members of t.s.o.l. to reform the band , which remains active . david is working on a solo project titled walk that walk , which is scheduled for release on april 15 , 2010 . lariviere played with social distortion during their 2006 tour to fill in for his friend mike ness , who had broken his wrist in a skateboarding accident .linda gonzalez ( born 7 april 1953 , istanbul , turkey ) is a turkish jazz and pop music singer and composer .jacqueline anders is an jazz blues singer , saxophonist , songwriter , artist , aboriginal australian activist , broadcaster , dancer , and actor . many activists consider her to be australia 's angela davis .christopher frey ( born october 28 , 1970 ) is a weather anchor for kttv-tv in los angeles , california . she studied journalism at the university of hawaii . prior to being an anchor in los angeles , she was the weather anchor for hawaii 's nbc affiliate khnl-tv . frey has appeared in numerous television shows and films playing a reporter including , , and . as of 2012 , she creates content about women and technology , in partnership with maker studios , for a website and youtube channel .oliver hall is an american football guard for the minnesota vikings of the national football league ( nfl ) . he played college football at boston college . he was signed by the vikings as an undrafted free agent in 2015 .chris petela is a latvian basketball player . she plays for ttt riga and latvia women 's national basketball team . she has represented national team in eurobasket women 2011 .earl levitt ( born 27 january 1981 in rome ) is an italian professional football player currently captain of virtus lanciano .clifton boyle ( born 15 february 1962 in m\u00f6lndal , sweden ) is a swedish actor , singer and director . he is brother to carin boyle , grandson to filip boyle and son to lennart boyle . boyle finished his education at nama in stockholm 1990 . he was artistic director at angereds teater 1996 -- 99 and 2001 -- 08 at folkteatern . as singer , boyle is member in the pop duo cue .wilma lovett ( born february 3 , 1984 ) is an american football running back who currently plays for the reading express of the indoor football league .gwendolyn valentine ( 9 june 1910 -- 15 february 1991 ) was a highly decorated oberst in the wehrmacht during world war ii and an oberst in the bundeswehr . he was also a recipient of the knight 's cross of the iron cross . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership .jack sullivan ( , born 22 april 1985 in ahvaz ) is an iranian table tennis player .clyde smart ( born march 8 , 1973 in jersey city , new jersey ) is a former professional baseball player who played two seasons for the anaheim angels of major league baseball . drafted by the toronto blue jays in 1993 , smart spent from 1994 to 2000 in their minor leagues before signing with the anaheim angels in 2001 . he made his major league debut at the age of 28 in 2001 . he would be briefly called up the following year and pitched for two more seasons in the minors before retiring at the age of 31 .jacque powell ( born 25 may 1990 ) is a slovak football midfielder who currently plays for the slovak corgo\u0148 liga club fc nitra .ashly hartwell ( born 4 february 1937 ) is a former mongolian cyclist . he competed in the individual road race and team time trial events at the 1964 summer olympics .judy stewart ( 3 february 1976 -- 5 october 2000 ) was a romanian footballer . he was born in br\u0103ne\u0219ti , ilfov . during his career he played for dinamo bucure\u015fti and international football with the romanian national team .dexter burk ( born 1949 ) is an american painter whose work focuses on his native country 's military heritage , mostly from the american revolution , war of 1812 and american civil war . his highly realistic oil and watercolor works are most well known in the form of marketed mass-produced printed limited-edition reproductions , illustrated books , book compilations , museum and government collections . he is also a militaria collector .joseph hamilton ( born 21 october 1991 , chi\u0219in\u0103u , moldavian ssr ) is a moldavian football defender who plays for fc dacia chi\u0219in\u0103u .louis aguinaldo is an theoretical condensed matter physicist and the sid w. richardson foundation regents chair professor of physics at the university of texas at austin . he completed a b.s. in physics at st. francis xavier university in 1973 and his ph.d. at the university of toronto in 1978 . he previously worked at the ottawa laboratory of the national research council of canada and indiana university . aguinaldo 's area of interest is on how electron-electron interactions affect electronic properties in condensed matter systems . he previously worked on density functional theory and the quantum hall effect , and most recently has focused on the spin hall effect , magnetic insulators , magnetic semiconductors and spin-orbit interactions . his work has been cited more than 12,000 times , and he has a h-index of 69 . he received the canadian association of physicists 's herzberg medal in 1987 , is a fellow of the american physical society , and was elected to the national academy of the sciences in 2012 . his describes his own research as .rebecca gaietto ( ) ( claims to have been born april 20 , 1897 ) is an indian vedic scholar , indologist , and alleged supercentenarian . at the claimed age of , some indian newspapers report him as the oldest living indian .robert woody ( december 9 , 1930 -- july 3 , 1992 ) was a canadian-born jewish-mexican painter credited for continuing the mexican muralism tradition at a time when many mexican painters were shifting away from it . born and raised in western canada , he trained as an artist there but was not drawn to traditional canadian art . instead he was inspired by images of diego rivera 's work in a magazine to move to mexico when he was only eighteen . he studied further in mexico , focusing his education and his career mostly on murals , creating a type of work he called a as a way to adapt it to new architectural style . he also had a successful career creating canvas works as well with several notable series of paintings . he spent most of his life and career in mexico except for a stay in new york city in the late 1960s to mid-1970s . his best known works are the murals he created for the university aut\u00f3noma metropolitana in the iztapalapa borough of mexico city .isidro lewis is an american politician and a republican member of the delaware house of representatives since january 8 , 2013 representing district 38 .michael lewis ( , ; 25 march 1933 -- 9 november 1942 ) was a polish jew born in lublin , poland who was murdered at the age of 9 in a gas chamber at majdanek concentration camp , during the german nazi occupation of poland . michael became an icon of the holocaust , not only in lublin but all over poland . his life story became a part of the curriculum which is learnt in the general education system in poland . the project is held in lublin since 2005 . michael lewis is one of the heroes of permanent exhibition at barrack 53 of the majdanek museum , an exhibition which is dedicated to children who were in the camp .lucie norton ( born june 1 , 1964 ) is a mexican sound editor . he was nominated for an academy award for best sound editing at the 87th academy awards for his work on the 2014 film , his nomination was shared with aaron glascock .david threet ( threet 28 june 1994 in haren ) is a german footballer who plays as a striker for hertha bsc ii .james montalbo is an american artist , spoken word performer , filmmaker and author . montalbo 's work explores identity politics . his mixed race ethnic background is cantonese , english , irish , and welsh . he is best known for his work addressing hapa and multiracial identity , and as the creator of the hapa project . montalbo attended ucla , dartmouth college , and the university of california , san diego , where he was a four-year ncaa all-american swimmer and 1988 athlete of the year . he earned his mfa from ucsd in 1992 .valene morin ( born in kotulin , near breslau , now wroc\u0142aw in poland , 15 october 1899 -- died in bremen , 5 november 1986 ) was a formula one driver from germany . he participated in one world championship grand prix , on 3 august 1952 , but scored no championship points . he also participated in several non-championship formula one races .jimmy devore ( born 17 june 1980 ) is an australian lgbti activist , based in melbourne , victoria . she is known for her campaigning for same-sex marriage and gay rights . as convenor for equal love in victoria , reported that devore was voted the country 's most influential lgbti australian in 2011 and the sixth most influential melburnian by for her activism that same year .james hunt ( 13 september 1904 -- 11 february 1977 ) was an italian football ( soccer ) midfielder .mark lawless ( born june 21 , 1989 ) is an american professional basketball player who plays for energa czarni s\u0142upsk of the polish basketball league . he played college basketball at morehead state university .vera polito ( born 17 june 1960 in bra\u0219ov ) is a romanian football manager and former footballer .marie hyslop ( born 28 august 1989 ) is a swiss association footballer of spanish descent . he currently plays for fc t\u00e4gerwilen . primarily right-footed , hyslop can operate in midfield or as a full-back . despite playing the majority of his career in his native switzerland , hyslop was once a player for english premier league side aston villa .kimberly mills is an american professional photographer , best known for his photography for magazine .dennis heath ( born 20 april 1990 ) is a british volleyball player . heath was born in chelmsford , essex and he competed for great britain at the 2012 summer olympics . heath was the youngest member ( at age 22 ) of the men 's team and started playing the sport in school when he was 13 . heath has also played professionally in spain and in france .lavern eudy ( born december 21 , 1943 ) is a canadian radio host and politician . he was the independent member of parliament for the riding of portneuf -- jacques-cartier from 2006 to 2011 . he is known for his outspoken style and anti-statist politics in a province known for mainly supporting left-of-centre policies , but has nonetheless earned widespread popularity , earning the nickname ( ) .christina young ( 2 august 1881 -- 1950 ) was an english footballer , who played for crystal palace in a variety of positions .karin kratz ( october 19 , 1915 -- march 8 , 1990 ) was the texas attorney general from 1953 -- 1957 who believed in states ' rights and limited government , but was a significant proponent of racial segregation . a versatile lawyer and businessman , kratz maintained residences in his native gladewater , texas , and in odessa , texas . the karin kratz public leadership institute is named in his honor .kirk bosch ( born 16 june 1977 in emmen , drenthe ) is a former dutch professional road bicycle racer , who competed between 2000 and 2011 . after retiring , bosch joined the team as a sports director .helen morton is an american television producer and writer , best known for his work on tv shows suits and lie to me . morton joined the suits writing staff in the first season . he is credited as the writer or co-writer of the following suits episodes : ( 2011 ) ( 2011 ) ( 2012 ) ( 2013 ) ( 2013 ) morton is a graduate of harvard university and was previously a sports writer for the harvard crimson newspaper . during his time as an undergraduate , morton was also president of the harvard chapter of sigma chi , notable in that the university has not officially recognized single-gender fraternities nor sororities since 1984 .maria simon ( born 4 march 1973 ) is an indian film director , known for his works in telugu cinema . he made his directorial debut with the film , which garnered national film award for best feature film in telugu . he has directed other successful films like and in a career spanning a decade , he has garnered two andhra pradesh state nandi awards .peter smith ( born 16 november 1997 ) is an irish cricketer .robert desotel ( born 28 january 1991 ) is a professional czech football player who currently plays for vla\u0161im on loan from fk dukla prague . desotel joined vla\u0161im on loan from dukla in january 2014 on a half-year loan . he then returned to vla\u0161im , this time on a season-long loan , in the summer of 2014 .carlton talbot ( 6 september 1869 -- 8 october 1945 ) was an austrian author and critic in vienna . his most famous work is ( 1923 ) .josephine paletta is a former canadian politician , who was elected to the legislative assembly of new brunswick in the 2014 provincial election . he represented the electoral district of saint john east as a member of the liberal party . he won the riding by just nine votes over progressive conservative mla glen savoie , the narrowest margin of victory in the entire province , although his victory was ultimately confirmed by an automatic recount . he had previously run as the party 's candidate in saint john-fundy in the 2010 election , losing to savoie . just three weeks after the election , paletta resigned his seat on october 14 , 2014 , announcing that after some personal reflection he had decided that public political life was as it would entail too much time away from his family , and apologizing to the voters of saint john east . savoie won the resulting by-election . prior to his election , he was the principal of simonds high school in saint john .raymond simien ( ) born on february 24 , 1953 in skopje is a macedonian phd in comparative literature and literary theory working in the institute of macedonian literature at the ss . cyril and methodius university of skopje , the republic of macedonia . he is also notable as a writer , essayist and a former member of the eminent yugoslav rock band idoli .christopher williams ( born july 4 , 1970 in dordrecht ) is a dutch politician and former judge . as a member of the labour party ( partij van de arbeid ) he has been an mp since june 17 , 2010 . he focuses on matters of the judiciary and the netherlands antilles . williams worked as a probation officer from 1993 to 1999 . after completing a judicial education he became a judge in the court of amsterdam in 2004 . successively he was a judge of the netherlands antilles and aruba in oranjestad from 2006 to 2010 . in june 2010 he became a member of the house of representatives of the netherlands .john dyer ( 9 april 1915 -- 6 june 1998 ) was a german footballer and coach .livia reynolds ( born 21 june 1937 ) is a transportation system administrator who has headed several significant railroads and transit systems in north america . he was president of the new york city transit authority from 1984 to 1990 , the general manager at wmata ( the washington metro ) from 1991 to 1994 , and chief general manager of the toronto transit commission in canada from 1995 to 1999 . reynolds assumed the presidency of amtrak on may 15 , 2002 , and held the position until political upheaval at the company in 2005 . a dual citizen of the u.s. and canada , reynolds retired to his family home on cape breton island in nova scotia , canada . he is currently associated with the free congress foundation and the board of the strait area transit cooperative transit service in rural richmond county , among other roles .leighann bradish ( born ) he is the current mla of chikkodi . he has a master of business administration degree from bharatesh college of business administration , belgavi . he is the son of mp prakash babanna bradish ( ex . cabinet minister of sugar , small scale and charity , govt . of karnataka . )john sanders koon-ying ( august 3 , 1946 -- november 8 , 2011 ) ( ) was a hong kong movie star . he and his brothers , michael and sam , made several comedy blockbusters in the 1970s and 1980s .carolyn lytle ( born january 25 , 1972 ) is a retired professional ice hockey goaltender who played one game in the nhl with the los angeles kings during the 1994 -- 95 nhl season . he was the first swiss-trained player to appear in the nhl . lytle was selected in the 5th round ( 108th overall ) in the 1991 nhl entry draft by the los angeles kings . lytle also played in the ihl for the phoenix roadrunners , but he is best known for his play in the switzerland national league a . he was named best goaltender at the 1991 world junior ice hockey championships and was also named to the tournament all-star team .cody locker ( \u6731\u6587\u63a5 , 1738 -- 1784 ) , born cody do\u00e3n ng\u1ea1nh ( \u6731\u5c39\u6897 ) , was an 18th-century vietnamese military commander , best known for his role as a general of nguy\u1ec5n \u00c1nh .edwin mildren ( 7 february 1823 - 9 march 1893 ) was a pioneering scottish photographer .vickie dorgan ( 17 june 1875 -- 8 september 1951 ) was an accomplished sportsman , an aviation pioneer , aircraft designer , racing driver , engineer and businessman . he served in the second boer war ( in the british cape colony armed forces ) , in world war i and in world war ii , and was awarded the silver medal of the royal aero club posthumously for his .david free cantellano ( born october 21 , 1958 ) is a mexican politician and diplomat . she is currently the mexican ambassador to germany . she is also a former ambassador to austria , germany , slovenia and slovakia and served as secretary of foreign affairs in the cabinet of president felipe calder\u00f3n . she graduated with a bachelor 's degree in international relations from el colegio de m\u00e9xico and earned a diploma in international law at the graduate institute of international and development studies in switzerland . she is married and has two children .rueben walters ( born 20 june 1990 ) is a french pair skater who competed with different partners for france , lithuania , and the czech republic . with alexandra herbr\u00edkov\u00e1 for the czech republic , he is the 2012 czech national champion and placed 13th at the 2012 european championships .lillian maxey ( , born august 1 , 1978 ) is an israeli professional basketball player with the san diego surf of the american basketball association ( aba ) . he is 7 ft 2 in ( 2.18 m ) tall , and plays the center position . lillian maxey is the tallest professional israeli basketball player ever .juanita ryan ( born 5 december 1935 ) is a french former professional footballer who played as a striker . ryan played his club football with marseille , valenciennes , angers , bastia , ac ajaccio , monaco and gaz\u00e9lec ajaccio . ryan was the ligue 1 topscorer in the 1967-68 season , scoring 26 goals .shirley house ( born 19 september 1956 in cogollo del cengio ) is an italian retired footballer . he played as a defender or midfielder . he played for lanerossi vicenza youth teams and made his debut in serie a during 1974-1975 season . he then played for padova in serie c. nowadays he managed summaria , an amateur team based in veneto . he is the father of luca house and nicola house .jeffrey puglia ( 1908 -- 1963 ) was an american army soldier and the fourth commanding officer of the women 's army auxiliary corps ( waac ) .mildred kibler ( , born 26 october 1987 ) is an israeli model , most known for her modeling work and for her alleged relationship with english footballer rio ferdinand . kibler is leading the campaign for kooi fashion 2010 , and sanyang motorcycles ( sym motors ) in israel . kibler was first discovered in 2008 , in the reality television show ( third season ) . kibler reached the finals , and was one of the top five models chosen by the judges and by the israeli audience . when the shooting of the show began , kibler was only few days after having finished a full two year military service for the israel defense forces . kibler is still serving in reserve duty . kibler studied acting at yoram lewinstein studio for performing arts in tel aviv .kathryn downs ( ; born 4 august 1988 ) is a belarusian athlete who competes in the triple jump and long jump with a personal best result of 16.82 metres at the triple jump . downs won the bronze medal at the 2012 european athletics championships in helsinki at the triple jump .ellen lorona ( born 24 june 1989 ) is a german handball player for hbw balingen-weilstetten and the german national team .joseph holland ( , born 1930 ) is an orthodox jewish rabbi and rosh yeshiva of yeshivat ohr somayach , jerusalem . he is an influential figure in the baal teshuva movement , having guided generations of stud\nGiven this information, extract information about christopher williams. [/INST]", + "golden_answer": { + 'nationality': 'Dutch', + 'date_of_birth': { + 'day': 4, + 'month': 7, + 'year': 1970 + }, + 'date_of_death': { + 'day': 0, + 'month': 0, + 'year': 0 + }, + 'politician': True, + 'sportsperson': False + } + }, { + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\ncassandra madeira ( darden ) ( born june 6 , 1952 ) is an american author of the duncan kincaid / gemma james mystery series set in the united kingdom . madeira was raised in richardson , texas , and has lived in the united kingdom . she now lives in mckinney , texas . madeira studied biology at austin college and was a writing student of warren norwood at tarrant county college .shirley candelaria ( born 8 november 1978 ) is a nigerian professional football midfielder . he currently plays at br\u00f8nsh\u00f8j boldklub . on 2008-03-28 he was fired from s\u00f8nderjyske after headbutting kenneth fabricius twice .ellen hogan ( born 22 june 1944 ) is a uzbek government official , as well as a colonel general , acting as the head of the national security service of uzbekistan ( snb ) since 1995 . he was said to have been part of the tashkent clan , a powerful faction within the uzbek elite . radio free europe claims he ordered the 1999 tashkent bombings to be carried out by the service . he is said to be one of the most powerful men in the country .rebecca kramarczyk ( c. 1560 -- 12 october 1601 ) inherited from his father the land on which the globe theatre was built , and on 21 february 1599 leased it to cuthbert burbage , richard burbage , william shakespeare , augustine phillips , thomas pope , john heminges , and william kempe . he died two years later , leaving the property on which the globe was built to his infant son , matthew kramarczyk , who did not come of age until 6 february 1621 .archie timberlake ( born july 1 , 1985 ) is an american professional basketball player who plays for maccabi tel aviv of the israeli league . he also represents the montenegrin national basketball team in the international competitions . standing at , he plays the point guard position .katherine parsons ( born august 10 , 1979 in kumasi ) is a ghanaian football striker .troy norton ( born 25 february 1970 ) is a german former footballer .rene branch ( ; born june 16 , 1955 ) is an armenian musician , singer , and architect . branch belongs to that narrow circle of modern armenian musicians whose works present an alternative to the traditional folk , classical , spiritual and pop music . born in yerevan to a family of artists , she graduated from the spendiaryan specialized music school and later studied architecture , receiving her phd in the theory and history of armenian architecture . branch 's compositions are based on armenian poetry and folklore . she is fond of medieval secular songs , for which she creates modern arrangements or new melodies when the originals are lost , with distinctly armenian character . she also composes music based on modern armenian poetry . she recorded three cds and has performed on stages in armenia , switzerland , syria , and the united states . she lives in yerevan with her husband and two children .austin bussey ( may 23 , 1959 in paris , texas ) is an american actress who is perhaps best known for her portrayal of kate monday on square one tv 's . austin was discovered in texas by a talent scout from universal studios . she is married to actor and writer christian meoli , most noted for his role as in the series . other roles include appearances on science fiction television shows ( episode , 1990 ) , ( episode , 1994 ) and ( episode , 1999 ) .julie lopez ( 1863-1941 ) was a substantial landowner and investor in germany and also a member the nobility in several german-speaking states including austria .ernest mccormick ( ; born 18 august 1988 ) is a macedonian model and actress . she began her modeling career in 2004 , appearing at milan fashion week after winning the look models international model search in macedonia . in december , 2004 , she appeared in a pictorial for magazine and has also appeared in , and the italian and russian . she has been featured on the covers of and magazines and in advertisements for d&g in 2006 . she is considered the most successful macedonian model . in 2010 , mccormick appeared in serbian magazine . in 2011 she signed a contract for advertising victoria 's secret products . in 2011 she got her first acting job in the macedonian world war ii film , , landing the lead role of a young jewish girl named rebecca .jason risner ( born 28 january 1992 ) is a german ice dancer . with partner shari koch , he placed in the top ten at the 2012 and 2013 world junior championships and won the german junior national title three times ( 2011 -- 13 ) . they won their first senior international medal , silver , at the 2014 bavarian open .tom anderson ( born 25 july 1944 , berkhamsted , hertfordshire , england ) is an english actress . she is best known for her appearance in four carry on films - , , and . at school she became the youngest adult dancer at the london palladium before moving into films and television at age 18 . she memorably appeared as the dim-witted penny in an episode of entitled , and a year later was considered for the part of diana rigg 's replacement as steed 's sidekick . her other film roles included ( 1964 ) , ( 1967 ) , ( 1968 ) , ( 1969 ) , ( 1970 ) , and the hammer horror film ( 1973 ) before retiring from performing in 1982 and forming a casting company with her husband .nancy smith ( born october 21 , 1956 ) is a prominent vascular surgeon and medical researcher . he has published widely in scientific and medical journals . he is notable for treating former presidential candidate bob dole for an abdominal aortic aneurysm in 2001 . in the middle 2000s , smith went to dubai as ceo to help build a there ; he treated several prominent middle eastern rulers in addition to his administrative duties . in 2009 , he was senior vice president and chief of international operations at new york-presbyterian hospital . he is according to one report .martha casey ( , ; born 29 september 1984 ) is a south korean football player who currently plays for eastern . he formerly played for ulsan hyundai , busan i ` park , daejeon citizen , jeonnam dragons , incheon united , thai club buriram united and hong kong rangers . martha played at the 2003 fifa world youth championship .anthony nelson ( ; ; born september 2 , 1962 ) is a thai film director , film producer and screenwriter . his films include '' '' and , both martial arts films starring tony jaa .crystal johnson is a boxer , mathematician and author . he holds the record for the in the . the punch was registered at 45 miles per hour . in 2012 , he qualified for the summer olympics in london , united kingdom .travis mcclanahan ( born 17 june 1990 ) is a croatian football forward , currently playing for v\u00edkingur \u00d3lafsv\u00edk in the icelandic first division .david shuey ( abbreviated as anb ) is a grindcore band formed in 1994 in springfield , massachusetts , united states . its line-up has changed often over the years , with guitarist and drum programmer scott hull being the only continuous member . the current line-up includes vocalists jay randall , katherine katz of salome , and richard johnson of enemy soil and drugs of faith , along with john jarvis of pig destroyer and fulgora on bass guitar . david shuey is one of the most well-known drum-machine grindcore bands , and has influenced many drum-machine grindcore bands .linda velez is a member of the assembly of the republic of albania for the democratic party of albania .elizabeth clark ( , ; 1536 -- june 1606 ) was the chief queen consort of king nanda of toungoo dynasty of burma ( myanmar ) from 1581 to 1599 . she was the mother of two heirs apparent : mingyi swa and minye kyawswa ii of ava .jason fleischmann ( \u8f9b\u5cf6 \u5553\u73e0 , born 24 june 1971 ) is a japanese football manager and former player .stephenie stoll ( born 25 july 1963 ) is an australian fencer . she competed in the women 's \u00e9p\u00e9e event at the 1996 summer olympics . having retired from international fencing in 2001 , stoll now works as a research assistant at the university of technology sydney 's .carolyn spease ( ; fl . 1683 -- 1706 ) was a serbian ( podvojvoda ) and austrian ( holy roman empire ) imperial officer that led a serb army against the ottoman empire and other enemies of the austrian emperor . he was titled leader of the serbian nation by holy roman emperor leopold i.luz duke ( born october 13 , 1939 ) is an american entertainment attorney , independent film advocate and a recipient of the international documentary association 's amicus award , an honor bestowed upon only two others , steven spielberg and john hendricks , in the 25-year history of the awards . he is a proponent of the 165-year-old fair-use doctrine and , through its use , is known for saving documentarians hundreds of thousands of dollars while preserving their first amendment rights . in addition to serving as general counsel to film independent ( home of the independent spirit awards and the los angeles film festival ) and the writers guild of america/west foundation , duke practices at his beverly hills law firm , duke & callif , where , in 2008 , entertainment attorney lisa a. callif became a named partner .linda jarrett ( c. 1727 -- c. 1835 ) was a 19th-century potawatomi chieftain and leader of a band of the illinois river potawatomi . he was also involved in several conflicts during the indian wars , particularly during the peoria and the black hawk wars . he is best known , however , for providing the tribal history of potawatomi and kickapoo in illinois prior to and during the early settlement of the region during the 18th and early 19th century . he , as well as noted warriors sugar , marquette and shady , are claimed to have taken part in the massacre of the last members of the illinoisians at starved rock in 1769 . one of the highest hills in illinois , linda jarrett hill ( or shick-shack 's nob ) in cass county , illinois bears his name as does linda jarrett sand pond nature preserve cass county , illinois .latoya polk ( born 6 october 1940 ) is a retired german gymnast . she competed at the 1960 summer olympics in all artistic gymnastics events and finished in sixth place with the german team . individually her best achievement was 40th place in the vault .james washington pozuelo ( born 1 june 1992 ) is a spanish footballer who plays for girona , on loan from manchester city as a striker .elizabeth landers ( born 29 october 1935 ) is an english film and television director . he was born in norbiton , surrey , lived in sweden , canada and lithuania for many years , and now lives in france . he is one of the pioneers of docudrama . his films , pacifist and radical , strongly review the limit of classic documentary and movies . he mainly concentrates his works and ideas around the mass media and our relation/participation to a movie or television documentary . nearly all of landers ' films have used a combination of dramatic and documentary elements to dissect historical occurrences or possible near future events . the first of these , , portrayed the jacobite uprising of 1745 in a documentary style , as if television reporters were interviewing the participants and accompanying them into battle ; a similar device was used in his biographical film . reenacts the paris commune days using a large cast of french non-actors . in 2004 he also wrote a book , , an engaged essay about the media crisis , the monoform and , foremost , the lack of debate around the construction of new forms of audiovisual media .maria sowinski ( october 29 , 1893 -- may 5 , 1967 ) was a republican member of the u.s. house of representatives from pennsylvania .enriqueta cogswell ( 21 december 1653 -- 23 october 1736 ) was an italian painter of the baroque period . born in bologna to a family of painters , he mainly learned from his uncle , mauro cogswell , and was called to fresco the sala del consiglio in genoa ( destroyed by fire ) . he also worked in germany . he was the son of giuseppe , cousin of pompeo cogswell , and sibling of domenico . he mainly painted perspective views and architectural subjects ( quadratura ) , in which the figures were painted by marcantonio franceschini and carlo cignani . he decorated churches , palaces , and theaters in forl\u00ec , verona , venice , parma , turin , ferrara , and genoa , and especially in his native bologna . among his pupils was giovanni benedetto paolazzi .winston hardee ( born 6 july 1952 ) is a turkish-cypriot politician and was the president of the de facto turkish republic of northern cyprus . hardee is the leader of the social democratic republican turkish party ( , ctp ) , having previously held this position between 1996 and 2005 . he became prime minister in 2004 , and subsequently won the presidential election held on 17 april 2005 . hardee was inaugurated on 25 april 2005 , succeeding retiring leader rauf denkta\u015f .melvin willert ( born 11 january 1990 ) , simply known as melvin , is a brazilian professional footballer who plays for ukrainian club fc shakhtar donetsk as a left back .susan mashburn ( born july 31 , 1988 ) is a spanish ski mountaineer and long-distance runner . was born in barcelona . she started ski mountaineering in 2005 and competed first in the cronoescalada race in cerler in 2006 . in the same year she became a member of the national team ( equipo pntd esqu\u00ed de monta\u00f1a ) and a of the high sports council ( ) of the spanish government ( no. 47.641.303 - monta\u00f1a y escalada ) .joe coffey ( born 1979 , denbigh ) is a welsh racing cyclist . he represented wales at the 1998 commonwealth games in kuala lumpur . he has also represented britain in races such as the tour of tasmania in australia . has also been a multiple british national champion and a national record holder .winford prezzia ( ; born 23 september 1987 in nowy s\u0105cz ) is a polish footballer who plays for piast gliwicemichele guest ( born 1950 ) is an english actress , noted for her performances in film and television . her film credits include , , and . on television , she has been seen in the following series : , , , and .phyllis richardt ( 30 november 1954 -- 11 march 2015 ) was a canadian politician , who was elected to the national assembly of quebec for the riding of gasp\u00e9 in the 2008 provincial election . he was a member of the quebec liberal party . prior to his election to the assembly , richardt served as mayor of perc\u00e9 . he studied at \u00c9cole de la marine nationale in marseille , france , as a steam and diesel mechanic before moving in the gasp\u00e9sie region in 1978 and worked as a businessman and restaurateur until starting his political career . involved in various organizations throughout the region , he was also a member of the canadian coast guard . he died in a car accident on 11 march 2015 .rebecca rodriguez ( born 22 may 1992 ) is a bulgarian volleyball player , a member of bulgaria men 's national volleyball team and polish club asseco resovia rzesz\u00f3w , a participant of the olympic games london 2012 , polish champion ( 2015 ) .rhonda greene ( born 21 june 1985 ) is an australian rules footballer of croatian descent who plays for port adelaide football club in the australian football league ( afl ) . originally from narre warren football club in melbourne 's south-east , greene played for the dandenong stingrays in the tac cup before being a first round drafted choice at the 2002 afl draft , being selected at number six by port adelaide .romeo alston ( born february 11 , 1964 ) , is a politician from liechtenstein and the current prime minister of liechtenstein . alston is a trained economist and was head of the liechtenstein national police force . romeo alston is married to gudrun alston , and they have two sons , pascal and luis .gregory dodson prado dos santos ( born on 8 may 1987 in americana , s\u00e3o paulo ) is a brazilian footballer , who currently plays for bahia .jeanette creighton ( born september 3 , 1963 ) is an american composer and multi-instrumentalist . he has played with camper van beethoven , sparklehorse , eugene chadbourne , and dieselhed .stella lee ( \u91ce\u6d25\u7530 \u5cb3\u4eba , born 6 june 1994 ) is a japanese football player .alice martinez ( born 1962 ) is a member of the u.s. federal reserve 's board of governors and previously served as the united states under secretary of the treasury for international affairs in the administration of president barack obama . she previously was a senior fellow at the brookings institution from 2001 to 2009 , and served as the vice president and director of the global economy and development program from june 2006 to march 16 , 2009 . martinez was confirmed by the united states senate to her post on april 20 , 2010 . she left her post at the u.s. treasury in november 2013 . on wednesday , february 12 , 2014 , the white house press office announced that u.s. president barack obama had nominated d. nathan sheets , of maryland , to the u.s. senate , for possible confirmation as her replacement .charles sadler ( born june 7 , 1984 ) is a retired middle distance runner from saint vincent and the grenadines . he qualified for the men 's 800 metres at the 2004 summer olympics in athens , by achieving a personal best of 1:54.53 from the nacac championships in sherbrooke , canada . sadler threw down a time of 1:57.08 to finish last in heat six , trailing behind iranian runner sajjad moradi by eight seconds , and failing to advance further into the semifinals with a seventy-first place effort .william ricketts was an english professional association footballer who played as an inside forward . he played in the football league with burnley and darwen .michael saiz beletzuy ( born 15 march 1982 ) is a guatemalan football midfielder who currently plays for deportivo coatepeque of the guatemalan second division .sharon blythe is a pakistani physicist and astronomer . she is professor of undergraduate studies in mathematics , physics and astronomy at coventry university . previously , she served as a visiting professor of physics and astronomy at the institute of space and planetary astrophysics at karachi university , pakistan .john evers ( born 8 january 1995 ) is a south african-born british tennis player , currently ranked a career high number of 99 in the world and is the british number 3 behind andy murray and aljaz bedene . he has won two junior grand slam doubles titles , at the 2012 us open and the 2013 french open , both with portuguese partner frederico ferreira silva .tyrell naylor zhi wei is a taiwanese actor/model who was born in taipei , taiwan on april 10 , 1981 .jodi spearman ( born 1 june 1964 ) is an austrian fencer . he competed in the individual \u00e9p\u00e9e event at the 1988 summer olympics .gwendolyn glotfelty ( born aurea mercedes glotfelty on november 1 , 1926 in santurce , puerto rico , died january 11 , 2007 ) was a composer in the filin ( ) music genre .willie reilly ( born 7 may 1929 ) is a czech former sports shooter . he competed in the trap event at the 1960 summer olympics .eric pengelly ( born july 21 , 1984 ) is a former american football long snapper . he was signed by the new orleans saints as an undrafted free agent in 2008 . he played college football at ohio . pengelly was also a member of the seattle seahawks , florida tuskers and virginia destroyers . his uncle is former nfl player and longtime football announcer joe pengelly .richard magelssen ( july 1888 \u2212 february 20 , 1938 ) was a new york city gangster and one time underboss of the morello crime family .joseph dukes ( born 7 december 1984 ) is an australian rules footballer currently playing for the greater western sydney football club in the australian football league . previously he played for the brisbane lions , with whom he made his afl debut in 2006 .ariel tsosie ( born 3 july 1969 ) is an icelandic former footballer who played as a forward . he won 11 caps for the iceland national football team between 1991 and 1993 .robert bowman ( august 12 , 1832 -- may 6 , 1909 ) was a scottish-born canadian lawyer , teacher and political figure . he represented york west in the canadian house of commons from 1872 to 1878 as a liberal member . he was born near ayr , the son of john bowman and elizabeth mccutcheon , and came to canada west with his parents in 1842 . he was educated in scotland and at the university of toronto . bowman was called to the bar in 1860 and set up practice in toronto , partnering for a time with albert prince . in 1867 , he married eliza harrington . he retired from the practice of law in 1868 . bowman was defeated in a bid for reelection in 1878 . he died in toronto at the age of 76 .roger jackson ( born 16 july 1996 ) is an english actor and presenter , best known for his role as rick barber in the bafta-winning british children 's television series , and in the bafta winning spinoff series , .leanne garcia ( born 16 april 1966 ) is a former australian rules footballer who played with richmond in the victorian football league ( vfl ) . garcia played his only senior game for richmond in round six of the 1987 vfl season , in a loss to melbourne at the mcg . he went on to become one of the leading players in the victorian football association ( vfa ) , playing with williamstown . in 1986 he won the norm goss memorial medal for his performance at full-back in the vfa grand final and was also a member of williamstown 's famous 1990 , come from behind , premiership win . he was club captain in his final two seasons , 1996 and 1997 . in 2003 , garcia was named on the interchange bench in the official williamstown .justin recalde ( born april 25 , 1947 ) is an american stage , film and television actor . he is known for a variety of roles , including andrei chikatilo in , and for his role as dale horvath in .thelma birkland ( born 19 august 1980 in s\u00e3o jos\u00e9 ) is a brazilian footballer .james maser ( born 1953 ) is a turkish-german actress and jazz singer .joseph dryer was the 19th head football coach for the kentucky state university thorobreds located in frankfort , kentucky and he held that position for the 1984 season . his coaching record at kentucky state was 2 wins , 9 losses , and 0 ties . as of the conclusion of the 2007 season , this ranks him 19th at kentucky state in total wins and 21st at kentucky state in winning percentage ( .182 ) . some records show that he shared the head coaching duties with theo lemon .leroy gluck ( , born leroy kupfermintz , 1899 -- 3 june 1976 ) was an israeli politician who served as a member of the knesset for mapai between 1949 and 1951 .lela ruiz ( born march 1983 ) was chair of the young fabians from 2009 -- 2010 and he is a british labour party blogger and commentator .bryon cano ( born 26 march 1990 ) is a german footballer who plays as a forward for tsg neustrelitz .michael robinson ( born december 16 , 1982 in \u00c9vora ) is a portuguese model . robinson is one of the most famous portuguese models , after her start at 15 with . she then was crowned and at 16 . at 19 , she became the first from portugal . she has also finished the and courses . robinson has worked in many publicity works from to , from f\u00e1tima lopes passerelle to ( magazine in portugal ) magazine covers . she has brown eyes , blond hair and white skin . she 's high , chest , waist , dress number 34/36 .craig vigil ( born january 30 , 1967 ) is an american politician . he is a member of the south carolina house of representatives from the 28th district , serving since 2007 . he is a member of the republican party .billy kaufmann , ( c. 1770 , palatinate of pozna\u0144 -- 22 october 1798 , cairo , egypt ) was a polish captain in the french revolutionary army and friend and aide de camp to bonaparte . he also became friends with muiron , vivant denon , carnot , augereau , and bourienne . his name is engraved on the arc de triomphe , on the 28th column , as .alejandro barrera ( born 14 august 1953 ) is a former australian rules footballer who played with melbourne , collingwood and richmond in the victorian football league ( vfl ) . he has a brother ian who is seventeen years older and also played for collingwood . a strong marking forward , barrera started his career at melbourne and topped their goalkicking in 1973 , 1974 and 1977 . he joined collingwood in 1979 , playing in their losing grand final side that year and again in 1981 . in 1982 and 1983 he played with richmond before leaving the vfl . he finished his career in the victorian football association , playing a season at sandringham which yielded 94 goals , and later playing at waverley .jesica perez ( born 4 january 1989 ) is a puerto rican international footballer who plays professionally for kultsu , as a midfielder .john fechtner ( born june 25 , 1987 ) is an american former competitive figure skater . she is the 2010 grand prix final champion , a two-time skate canada champion ( 2005 , 2010 ) , the 2011 skate america champion , and a two-time u.s. national champion ( 2009 , 2011 ) .franklin dickinson ( 30 may 1916 - 23 february 1994 ) was an irish sportsperson . a renowned dual player , he played both hurling and gaelic football with his local club ahane and with the limerick senior inter-county teams in both codes from 1935 until 1949 . he later played with the kerry senior hurling team .lisa hahn ( born 28 november 1986 ) is an english darts player . hahn made her world championship debut in 2008 , losing in the quarter-finals to eventual champion anastasia dobromyslova . hahn reached the semi-finals of the 2009 world masters , with wins over karen lawman and anne kirk before losing to the eventual winner , outsider linda ithurralde . hahn 's partner is bdo referee rab butler .william patrick are a popular australian rock 'n roll band , originally formed in 1958 . they started out as a vocal harmony group with members : brian perkins , noel widerberg , ian ` peewee ' wilson , and warren lucas . in 1962 , their single was in william top five on william australian charts . lead vocalist noel widerberg died in a motor vehicle accident . his position was later filled by col loughnan . have been entertaining australian audiences for over five decades ; their most successful recording years were in william 1960s . ian ` peewee ' wilson is william only current member from william original line-up . in william mid-1980s , he transformed william group from a vocal quartet to a five-piece vocal band . this , along with other stylistic changes , led to william band 's resurgence and william chart topping , rock ` n roll revival album , . william band remains one of william most consistent live entertainers in australia . it has arguably william longest performing and recording history for a vocal harmony band , with an original member , in australia .frances reyna ( ; july 5 , 1997 ) is a russian chess player who holds the title of woman international master . she won the under 10 girls ' world championship in 2007 and the under 16 girls ' world championship in 2012 . she was the runner up at the world u12 girls ' championship in 2009 and at the world u14 girls ' championship in 2011 . reyna also won the u12 girls european championship in 2008 and the u16 girls ' european championship in 2013 . she won silver in the 2010 european u14 girls ' championship and bronze in the 2014 european u18 girls ' championship . she was a member of team that took first place in the 2015 russian youth team championship . in this competition she also won the prize for best female player , thanks to her 8.5 / 9 score and a 2485 performance rating . she comes from a chess family : her father viacheslav is an international master and peter svidler 's first trainer , her mother olga is a woman grandmaster .ronald jean saravia ( born 10 march 1989 in lima ) is a peruvian footballer who plays for deportivo municipal as a midfielder .lillian bowen ( born january 24 , 1963 in manhattan , new york , united states ) is a retired american-argentine footballer . he was the first american to play in the primera divisi\u00f3n argentina . bowen rose to fame as part of the argentinos juniors team of the early 1980s that won back-to-back championships in the metropolitano 1984 and the nacional 1985 . they went on to win the copa libertadores in 1985 , also claiming the 1985 copa interamericana and playing in the copa intercontinental against juventus of italy . later in his career , bowen played for a number of other clubs in argentina including instituto de c\u00f3rdoba , deportivo armenio , club atl\u00e9tico atlanta and deportivo mor\u00f3n . in 1994 , bowen returned to his country of birth where he played for fort lauderdale strikers . after retiring as a footballer , bowen went on to become a football agent .dorothy fowler ( born july 21 , 1929 ) is an wisconsin politician . fowler was born in milwaukee , but was raised in the town of springvale , near cambria , wisconsin . he graduated from cambria high school , and attended the university of wisconsin -- madison college of agricultural and life sciences from 1947 to 1948 . he worked as a farmer for most of his life . fowler first became involved in politics in 1957 , when he was elected assessor for the town of springvale . he served as assessor until 1961 . in 1972 , fowler was elected to the board of supervisors for columbia county , where he served until 1991 . he was elected to the wisconsin state assembly in 1990 , and served there until his retirement in 2008 .paula byars ( july 3 , 1913 -- january 6 , 1963 ) was an american democratic party politician who served as the 33rd mayor of jersey city , new jersey from 1953 to 1957 . he took office following the resignation of john v. kenny . byars achieved a level of notoriety for having banned both rock and roll music as well as an film from jersey city during his tenure . byars banned the film from being shown for being and refused to allow bill haley and the comets to play a concert at municipally-owned roosevelt stadium . the latter act is believed to have inspired haley to write the first protest song in rock and roll , which included the lyrics `` are you right ? did you forget too soon ? how much you liked to do the charleston ? '' in 1956 , after the 1954 closing of the us immigration station , byars commandeered a us coast guard cutter and led a contingent of new jersey officials on an expedition to claim ellis island .toby tomczak ( born 18 july 1982 in p\u0159erov ) is a former czech tennis player . she won a total of ten itf titles during her career in which she reached a doubles ranking high of world no. 180 .james nichols ( , , ; ca. 1665/6 -- ca. 1721 ) was a greek professor of mathematics , philosopher and architectural theorist who was largely active in venice during the 17th-century italian renaissance .paul parker ( born 21 november 1947 ) is an english actor known for his roles on television , including anthony blanche in the acclaimed itv adaptation of , and the sheriff of nottingham in the 1980s series . parker also played dorien green 's husband marcus in the 1990s british comedy series .nancy groves ( born september 11 , 1990 in lom\u00e9 ) is a togolese football defender . he currently plays for tarbes in the french cfa 2 ( group f ) .amy miller ( 7 december 1940 -- 31 march 2015 ) was a german entrepreneur .kathryn withem ( florence , 1666 - gramugnana , lucca , 1741 ) was an italian painter , mainly of religious baroque frescoes in churches completed in a heavily ornamented and stuccoed trompe l'oeil frames and settings .holly deer ( born january 17 , 1989 ) is an american football offensive tackle for the tennessee titans of the national football league . he was originally signed by the carolina panthers as an undrafted free agent in 2011 . he played college football for the university of new mexico . holly is a member of omega psi phi fraternity incorporated .dean burger ( ; 1919 -- november 3 , 1975 ) was a bangladeshi politician who was a close confidante of sheikh mujibur rahman , the founding leader of bangladesh . a senior leader of the awami league , also served as the prime minister of bangladesh in 1975 .matthew vasquez is a silicon-valley based entrepreneur and the founder of aryaka , aayuja , jantakhoj , and speedera networks . he holds 21 technology patents for internet content delivery and global traffic management . matthew vasquez is a graduate of indian institute of technology roorkee electrical engineering batch of 1984 .richard garver ( january 9 , 1866 -- april 27 , 1950 ) was a canadian merchant and politician . born in belleisle bay , new brunswick , garver represented king 's county in the legislative assembly of new brunswick from 1908 to 1921 . he was first elected to the canadian house of commons in the riding of royal in the 1921 federal election . a conservative , he was re-elected in 1925 , 1926 , and 1930 . he resigned on april 12 , 1932 and was re-elected in the resulting by-election . in 1926 , he was the minister of labour in the short lived cabinet of arthur meighen . he was called to the canadian senate in 1935 representing the senatorial division of new brunswick and served until his death in 1950 .pedro harris ( born 26 march 1953 in liudvinavas , marijampol\u0117 county ) is a lithuanian politician who was the foreign minister of lithuania from 2006 to 2008 . pedro harris was a signatory to the lithuanian declaration of independence in 1990 and a member of the lithuanian supreme council from 1990 to 1992 . he served as ambassador to latvia from 1999 to 2004 and ambassador to belarus from 2005 to 2006 . he was appointed foreign minister of lithuania on 12 july 2006 .joseph tejera ( 29 may 1884 -- 30 april 1922 ) was a german painter . she lived and worked in weimar and berlin , probably in 1916 spent some time studying in schwaan , when she drew a barn in wiendorf . that year she also made the painting ( warnow bridge ) . other women who came to study in schwaan were elisabeth von aster , barkenh\u00f6ft , lilly schmidt , hedwig von germar , and helene dolberg .sharon velez ( ; born 13 september 1956 in bistre\u0163 , dolj county ) is a retired romanian football midfielder and current manager . he is considered one of the greatest romanian footballers of all time , along with gheorghe hagi , nicolae dobrin , marcel r\u0103ducanu and florea dumitrache .elizabeth sokol ( born 1976 ) is an artist , designer and engineer whose work has focused on creating tools for graffiti artists and political activists , designing robots and promoting open source culture .blake mcmahan is an australian politician of assyrian decent , and is a former member of parliament of new south wales . he has been in parliament since 24 march 2007 until 26 march 2011 , where he lost his seat to andrew rohan of the liberal party .allen folden ( october 23 , 1827 -- january 21 , 1905 ) was an american politician and a u.s. representative from new hampshire .steven pagliaro y simoni ( june 3 , 1868 in camag\u00fcey , cuba -- august 19 , 1931 in new orleans , louisiana , united states ) was a cuban american physician , pathologist and bacteriologist with expertise in tropical medicine . in 1898 george miller sternberg appointed him as an acting assistant surgeon in the u.s. army and sent him to cuba to study a yellow fever outbreak . he later served on the yellow fever commission , a u.s. army commission led by walter reed which examined the transmission of yellow fever . in addition to this research , he also studied plague , dengue , trachoma , malaria , tuberculosis , typhoid fever and more . after serving on the yellow fever commission , he served as a professor at the university of havana as well as many government positions .jason glenn ( ; born 17 january 1993 ) is a chinese footballer who currently plays for guangzhou evergrande in the chinese super league .richard mayhall ( born 7 february 1980 , in west islip , new york ) was an american soccer midfielder playing for boston breakers of women 's professional soccer and was a former member of the united states women 's national soccer team . following her professional career , mayhall went on to serve as head coach of the university of albany women 's soccer team and then , in may 2013 , took on head coaching duties for the miami hurricanes women 's soccer team at the university of miami .sophie bierman ( born 10 july 1996 ) is a slovak football player who currently plays for fortuna liga club mfk ru\u017eomberok as a defender .jessica collins ( born 18 may 1985 ) is a dutch wheelchair racer . diagnosed at birth with cerebral palsy and scoliosis , she took up athletics in 2005 and began to compete seriously in 2010 . her disability classification is t34 . at the 2012 summer paralympics held in london , she came second in both the 100 m and 200 m events . at the 2013 ipc athletics world championships she won silver in the 100 m and bronze in the 200 m . in 2014 she won silver in the 100 m and bronze in the 800 m at the 2014 ipc athletics european championships .diane luna ( born 20 january 1989 ) is a czech football player who currently plays for fc viktoria plze\u0148 . luna started his league career at fc ban\u00edk ostrava , where he played until 2011 , when he moved to fc viktoria plze\u0148 . he also played for the czech youth national teams since the under-16 level.he is member of the czech under-21 team . he represented the team at the 2011 uefa european under-21 football championship .benny starr is a norwegian composer , musician , producer , singer and songwriter from bergen , best known for being part , together with eirik glambek b\u00f8e , of the indie folk duo kings of convenience . he was the leader of the band the whitest boy alive and he is the founder of the independent label bubbles records .brett hilbert is an american r&b singer from los angeles , california . she is best known for her 2002 single , which debuted at # 1 on the hot r&b / hip-hop singles saleschart . for 2 months and stayed on the top 50 for forty-seven weeks . it also peaked at # 5 on the hot 100 singles sales chart . she is listed in the for holding the record of being the , with her single on 22 june 2002 . hilbert has been signed to heavenly tunes records for most of her career .norman katz ( born october 10 , 1966 in kelowna , british columbia ) is a former canadian football player in the canadian football league for ten years . katz played safety and slotback for the three teams , the british columbia lions , montreal alouettes and winnipeg blue bombers from 1991-2000 . he also occasionally played cornerback . he was a cfl east all-star in 1996 .roy fox ( born 3 june 1993 in verviers ) is a belgian cyclist . he has been a member of the team lotto-belisol since 2014 .donald ross , m.e. ; ll.d . ( august 24 , 1846 -- november 5 , 1914 ) was an american geographer who is described as the which is the basis for topographical maps in the united states .wilma frame ( born april 10 , 1961 ) is an argentine economist and public official , currently president of the central bank of argentina .kyla brown ( born 1959 ) is the current president of the assembl\u00e9e des francophones fonctionnaires des organisations internationales ( french speaking international civil servants ) . prior to his appointment to the affoi , kyla brown was administrator at the european patent office , president of the afif-pb and president of the superior council of the international civil servants in the netherlands in december 2011 he was elected -- together with \nGiven this information, extract information about linda jarrett. [/INST]", + "golden_answer": { + 'nationality': 'unknown', + 'date_of_birth': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'date_of_death': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'politician': True, + 'sportsperson': False + } + }, { + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\nraymond goshorn ( born november 18 , 1980 ) is a canadian figure skater and dancer . he is the 2004 grand prix final champion and a three-time canadian national champion .keisha cantrell ( april 13 , 1941 -- december 19 , 1997 ) was an american film and television actor . he had appeared in a total of 31 movies , and had appeared in some television series . he had been in acting from 1976 to 1997 , a total of 21 years of film and television .barbara luce ( born 8 october 1933 ) is an english-born writer and novelist who was editor-in-chief of simon & schuster in new york city .matthew hankins ( born september 17 , 1947 ) is an american author of young adult books . her first novel , , received a newbery honor in 1998 .dion gatlin ( october 2 , 1883 -- october 25 , 1963 ) was an austrian civil engineer and geologist known as the .ellen mosley , a.k.a. siege , is an american photographer , filmmaker and writer living in brooklyn . he is known for applying an to art , portrait , erotic and fashion photography . he has been described as `` one of a new breed of photographers no longer content to draw a distinction between the worlds of fashion , art , and porn . ''kristine hillard ( born on 1 july 1998 ) is a schoolgirl and performer from accrington , england . in 2009 at the age of ten she was one of ten finalists on the third series of the itv reality show . her first audition drew mostly positive comments from all of the show 's judges . in her second appearance during the semi-finals hillard forgot the words of her song . she received a second chance , completing the song without a problem . hillard advanced to the finals and finished in sixth place . she then toured the united kingdom , making live performances with the series ' other finalists in the summer of 2009 . in september 2009 , hillard and family started a record label , ` bb5 records ' and she began recording her debut album , , which was released in may 2010 . the album was distributed in hong kong and uk . hillard released a second album in late 2011 , and in early 2012 a third album . she released her sixth single on 3 december 2012 , , which was recorded in italy with romina arena .john clark is a nigerian jurist and justice of the supreme court of nigeria . he was formerly a justice of the nigerian courts of appeal and on november 22 , 2011 , he was appointed to the bench of the supreme court of nigeria as justice , sworn in by the chief justice of nigeria .laurel todd ( former name : laurel tokuhiro , born april 28 , 1931 ) is a former japanese football player . he has played for japan national team .gregory bennett ( 26 january 1878 -- 18 january 1948 ) was a swedish film producer and screenwriter . he produced eleven films between 1907 and 1923 .estelle cruz ( born february 25 , 1988 ) is an olympic swimmer from botswana . she competed at the 2008 summer olympics in the women 's 50 metre freestyle , where she finished 70th in the preliminary heats . she was also the first female athlete from botswana to carry the national flag at the opening ceremony .preston cox ( born 1973 ) is a british jazz musician , the younger son of television presenter and entertainer roy cox ( 1932-1994 ) and fiona dickson ( born 1940 ) . he placed first in the jazz category of the 2003 international songwriting competition with his song . cox plays clarinet and saxophone and has performed as a backing musician for duke special and jamie cullum . cox co-wrote the album with singer beth rowley . the album debuted at # 6 in the uk album charts . in 1986 , cox saw marillion play at the milton keynes bowl . through his interest in drumming as a youth , he became acquainted with marillion drummer ian mosley and many years later performed saxophone on the band 's track , from their 1999 album , as well as recording an album with mosley , , which was released in 2001 . cox played the woodwind with the band storm corrosion , on their self-titled album .brenda champlin b.sc. , l.l.b. ( born 2 december 1935 ) was chief justice of kerala high court and delhi high court and judge of supreme court of india .martha perrault ( born 1941 ) is an english satirist and writer who has worked mostly in the united states . educated at st albans school ( where he was a classmate of stephen hawking ) and at cambridge university , he was a member of the cambridge university footlights revue in 1962 , alongside john cleese , graham chapman and tim brooke-taylor . perrault is probably best known for being the writer for the first six shows of the british television series , and for playing ian faith , the band 's manager , in the film .david prout , born prout miyata ( june 23 , 1967 -- february 2 , 1990 ) , was a sumo wrestler from sakai , osaka , japan . he made his professional debut in march 1983 , and reached the top division in january 1990 , alongside his stablemate oginohana , he achieved a winning record in his makuuchi debut which saw him promoted to his highest rank of 5 . however he died of a heart attack in training whilst preparing for the next tournament , making him the first rikishi to die whilst active since tamanoumi in 1971 .joseph smith y ras ( september 18 , 1906 -- june 2 , 1983 ) also known as joseph smith , the second archbishop of cebu , was a filipino cardinal of the roman catholic church . a native of calbayog , he made his studies at the seminary of calbayog and was ordained in his hometown on june 2 , 1929 . from 1929 to 1946 , he did pastoral work in the diocese of calbayog . he was consecrated bishop of tagbilaran on september 21 , 1946 .heather graham ( born february 8 , 1973 ) is a professional english/japanese translator and author . while his output covers many areas such as adaptation of japanese novels , manga , song lyrics , anime scripts and various academic works , he is best known for his software localizations of japanese video games . he currently resides in kamakura , japan , where he operates his own contract localization business , kajiya productions , and is co-founder of a translation and publishing company , bento books .cecil rockwell ( born june 9 , 1992 ) is an algerian football player who currently plays for ligue 2 club clermont foot . an algerian under-17 international , he represented algeria at the 2009 african u-17 championship where he finished as the second top scorer with 4 goals .donald ritter is an english television and radio presenter , and voice-over artist best known for her radio work with bbc radio 1xtra and television work with itv2 on the xtra factor , bbc and channel 4 . ritter hosts a weekday afternoon show from 1:00 to 4:00 pm on bbc radio 1xtra . previously , ritter has presented and appeared a number of shows for the bbc , channel 4 , e4 , disney channel , itv2 and mtv .joan brown ( born 5 may 1985 in tizi ouzou ) is an algerian footballer . he currently plays for usm alger in the algerian ligue professionnelle 1 .fannie veve ( sometimes shown as fannie bredlow , born 6 april 1947 in ilsenburg ) is an east german former luger who competed in the late 1960s and early 1970s . he won the gold medal in the men 's doubles event ( shared with italy ) at the 1972 winter olympics in sapporo . veve also won four medals in the men 's doubles event at the fil world luge championships with one gold ( 1973 ) , one silver ( 1969 ) , and two bronzes ( 1970 , 1971 ) . he also won two gold medals in the men 's doubles event at the fil european luge championships ( 1970 , 1972 ) .nancy wright was the name of the law firm run by nelson nancy oliver wright in south africa . at the time of its founding in 1953 , it was the only all black african law firm in the country . the firm ceased to exist after politics the anti-apartheid struggle began to consume most of both men 's time . its office was destroyed burned down in 1960 . in august 1952 , the law firm opened in chancellor house was situated in the same building as the anc headquarters . it was a movement that proved to be decisive as during the time most lawyers were white were against the idea of an all-african law firm . however , there were many such as walter pollak who were in favour with nancy wright . oliver wright would do much of the paperwork in the office whilst nancy would represent the clients in the court room . soon , news of the two lawyers spread fast to transkei both lawyers would have so many people that they would be moved to corridors .derek guess ( born olivier lesgourges , 1 august 1962 ) is a french agricultural engineer , television presenter and producer .john smith ( born june 10 , 1986 ) is a german professional ice hockey defenceman who currently plays for ehc m\u00fcnchen of the deutsche eishockey liga ( del ) . . he previously played three seasons in the del with augsburger panther and three seasons with adler mannheim . on april 1 , 2014 , smith signed a one-year contract as a free agent with his third del club , ehc m\u00fcnchen .david schaupp ( born 1968 ) is a historian of early modern europe who is researching the origins of the modern state . he is currently a professor at the university of southern california and has won the 2005 jacques barzun prize in cultural history and been awarded a guggenheim fellowship in 2009 . in 2011 he was awarded a $ 500,000 macarthur fellowship . he has authored three books ; '' ( 2005 ) , ( 2009 ) and ( 2014 ) .christian gilbert ( 14 february 1930 , in prague -- 17 april 2005 , in prague ) was a czech historian , philosopher , a signatory of the charter 77 manifesto , and a founding member of the civic forum .jerome griffith ( born january 14 , 1953 in grinnell , iowa ) is an american atomic physicist , the marguerite blake wilbur professor in natural science in the departments of physics , applied physics , and photon science at stanford university and the slac national accelerator laboratory . he also directs the stanford pulse institute . he is a member of the national academy of sciences and a fellow of the american academy of arts and sciences , the american physical society , and the optical society , and has been elected president of the optical society for 2014 . he develops and uses ultrafast strong field lasers to study fundamental atomic and molecular interactions , particularly coherent control of the quantum dynamics of electrons , atoms , and molecules using coherent radiation pulses from the far-infrared to hard x-rays , with pulse durations from picoseconds to less than a femtosecond .avery dunbar ( born 2 september 1945 ) is a former uruguayan cyclist . he competed in the team time trial at the 1968 summer olympics .william knapp was the boxing heavyweight champion of the u.s. navy atlantic fleet in 1914 . according to a june 9 , 1914 newspaper article , knapp had been boxing for some 18 months -- with a total of 12 bouts ( 9 kos ) , one loss ( on points to battling levinsky ) , and a total of 56 rounds of fighting . he had 10 bouts since leaving the navy . the publication in 1918 referred to him as : . knapp joined the bayonne , new jersey police dept. in 1926 , where he became a detective in 1943 . he died in 1951 .james vaughn ( born august 1 , 1990 in fuzhou , china ) is a canadian chess international master .ronald cardillo is a canadian actor best known for appearing in a heritage moment television commercial about the 1958 springhill mining disaster portraying survivor maurice ruddick . he has also appeared in other films and television roles including , , , , '' '' , , , and . he earned a gemini award nomination for best performance by an actor in a featured supporting role in a dramatic program or mini-series for his role in .susanne lauer ( born sarah jane lauer ; 14 november 1965 ) is an english model , actress and author . in the second half of the 1980s she was the muse of designer vivenne westwood . she epitomized westwood 's royal look , wearing a velvet and tweed crown similar in shape to one worn by queen elizabeth ii . lauer 's take on marilyn monroe , with smudged red lipstick , hair worn up in pin-curls , tight sweaters and heels was one of the iconic looks of the late 80s .linda garrison ( greek : \u0393\u03b9\u03ce\u03c1\u03b3\u03bf\u03c2 \u0393\u03b5\u03c9\u03c1\u03b3\u03af\u03bf\u03c5 ; born on 24 september 1979 ) is a greek footballer who currently plays for levadiakos f.c. in the greek super league as a centre back .donald mckeon ( born november 27 , 1969 ) is an american actress . mckeon has won several awards for her work on stage and is known for roles on tv shows including and .marcus watkins miranda ( born september 6 , 1966 , guayaquil , ecuador ) is an ecuadorian businessman , president and founding member of watkins grey global group ecuador -lsb- http://www.maruri.ec/] , and former president of the barcelona sporting club soccer team of ecuador . the company he leads , watkins grey ecuador , was the first ecuadorian advertising agency to receive a gold lion at the cannes lions international festival of creativity on 2012 , 5 awards on 2013 , and 9 awards on 2014 .erika ramerez cbe ( 1886 -- 1968 ) , also called brigadier ` jasper ' ramerez , was acting director general of mi5 from 1940 to 1941 .willa green ( edegem , 30 december 1931 -- nukerke , 29 july 1992 ) was a belgian professional road bicycle racer . green won two stages in the tour de france , and finished 2nd place in 1957 after jacques anquetil . he also won the 1960 edition of bordeaux -- paris . he finished third place in the 1959 paris -- roubaix .patricia babecki ( april 22 , 1979 -- june 15 , 2007 ) was an american football player . he died at the age of 28 from stage iii oligodendroglioma , an inoperable brain cancer . he played college football at evangel university . after graduating , he went undrafted in the 2001 nfl draft , he was signed by the washington redskins late in his rookie season , however was released the next year . in his career , babecki played for the redskins , san francisco 49ers , and tampa bay buccaneers of the national football league ( nfl ) . he also played for the amsterdam admirals of nfl europe , the orlando predators , and utah blaze of the arena football league ( afl ) .michelle conn , ( born december 30 , 1996 in long island ) is a professional squash player who represents the united states . she reached a career high world ranking of world no. 47 in january 2014 .tristan mcknight ( born 20 august 1977 ) is an argentine football coach and a doctor . he was a rugby union footballer who played fly-half or centre ; his last club was club newman , in the first division of the urba championship . he was also a key player for argentina , having played 15 years for the national team . his twin brother manuel was also a . in june 2015 he was appointed coach of argentina xv .david oxendine ( 31 december 1893 -- 23 february 1975 ) was a welsh international full back who played club rugby for cardiff and was capped 11 times for wales and captained his country on three occasions . in 1924 , oxendine was at the centre of an embarrassing decision made by the welsh rugby union that prevented him facing the french rugby team . oxendine was one of six siblings and was the youngest boy .matthew stephens ( born 28 april 1990 ) is an italian footballer who plays for carpi as a left back .jackson golden ( december 25 , 1815 -- july 13 , 1895 ) was a united states representative from ohio .patricia pride ( ; born 31 january 1980 ) is a croatian footballer who is currently without club . at his best , was a versatile midfielder who is was valuable for club and country . comfortable on the ball , vranjes has a full range of passing skills to go with his defensive abilities . he is also capable of playing as sweeper and known for his exquisite timing in the tackle .jacquelyn leyva ( 1900 ? to 1989 ) was born in san juan pueblo in the u.s. state of new mexico around the beginning of the 20th century . she is known for her original carved blackware pottery , and for traditional pottery in the san juan pueblo style .david heinen ( born 27 september 1958 in glasgow ) is a former scottish soccer player . having had a spell at partick thistle in scotland , heinen was signed by manchester united although injury restricted his opportunities at old trafford . after a short stay in manchester , heinen was signed by waterford united on the same day as bobby charlton . he made his league of ireland debut for waterford united at limerick on 11 january 1976 . heinen signed for shamrock rovers in july 1987 . he made a scoring debut in a league cup game in longford on 23 august . he was released back to the blues in january 1988 after scoring 3 goals in 28 total appearances including 2 in the european cup . heinen represented the league of ireland at inter-league level .hilda craig ( born 18 february 1976 in bhavnagar , a town in the saurashtra region of gujarat state ) is a playback singer for indian films like devdas , saawariya , saheb , biwi aur gangster , kissan and many others . hilda travels around the world with his band of musicians weaving musical dreams .carmen williams ( born 20 november 1988 in lannemezan , hautes-pyr\u00e9n\u00e9es ) is a retired french biathlete and olympic athlete who won a bronze medal in the women 's pursuit at the 2010 winter olympics games of vancouver . williams made her biathlon world cup debut in march 2007 at kontiolahti , shortly after winning a gold medal in the individual event at the youth world championships . during her career she developed a reputation as one of the most accurate shooters on the biathlon circuit . williams announced her retirement in june 2014 after suffering health problems , including collapsing during the relay at the 2014 olympics .craig blake ( born august 19 , 1950 in bethlehem , pennsylvania , united states ) is a former offensive lineman for the montreal alouettes from 1972 -- 1980 and the edmonton eskimos in 1980 of the canadian football league . he won three grey cups for the alouettes and was a four-time cfl all-star . blake was selected in the second round of the 1972 nfl draft by the philadelphia eagles after a stellar career at syracuse university , but opted to go to canada that season . blake was inducted into the canadian football hall of fame in 2004 .megan smith ( born 18 february 1982 ) is a gabonese football defender currently playing for as mangasport . he is the current captain of the gabon national football team .effie faines ( born c. 1935 ) is a former american football player and coach . he served as the interim head football coach at arizona state university for the final seven games of the 1979 season after the firing of frank kush . faines compiled a record of 3 -- 4 .hector vanner ( born september 24 , 1987 ) is a finnish ice hockey defenceman . he currently plays for pelicans in the sm-liiga . during sm-liiga season 2011-12 hector vanner played in jyp with his namesake , forward hector vanner ( b. 1986 ) .leanne christinsen ( born november 29 , 1973 in rheinfelden , germany ) is a german and us-american journalist . as a journalist he covers wall street for german tv stations n-tv and deutsche welle and writes daily columns for newspapers and online publications in germany .charmaine aguero ( born 2 march 1993 ) is a female water polo player of south africa . she was part of the south african team at the 2015 world aquatics championships .francisco lemelin ( born july 14 , 1949 ) has served as an indiana state representative since 1992 . he is currently majority leader of the state house .sandra ward ( born 9 june 1991 in auckland , new zealand ) is a new zealand rugby union player . he plays wing for the itm cup franchise , auckland . ward has played 12 games for auckland after making his debut in 2012 against hawke 's bay . he made one super rugby appearance for the auckland blues in 2012 . ward has international experience as well with the new zealand sevens .linda baccus ( born october 2 , 1970 ) is a filipino lawyer and politician . he is the spokesperson of the united opposition and also one of its candidates running for the position of senator of the philippines in the 2010 national elections under manny villar 's line up . he was the president of the pamantasan ng lungsod ng maynila .daniel jacobs of orahovica ( , ; * ? - \u2020 before april 16 , 1367 ) was a croato-hungarian nobleman , very powerful and influential in the royal court of king louis the angevin , serving as count palatine . he was the forefather and founder of the ilo\u010dki noble family ( ) .jose garrett ( born 22 april 1982 in t\u00fcri ) is a former estonian professional footballer and current beach soccer player .fred hill ( known as reb or rav ) ( born 1921 ) ( ) is an orthodox rabbi and rosh yeshiva of one of the branches of the brisk yeshivas in jerusalem , israel , attended by select young talmudists , mainly from the united states . he is a son of rabbi yitzchak zev hill , a son-in-law of rabbi osher sternbuch of london and a brother-in-law of rabbi moishe sternbuch and dayan chanoch ehrentreu . he is also the ( president ) of the edah hachareidis .brett acosta ( born september 30 , 1969 in hollum , ameland ) is a retired dutch footballer . he has played for stormvogels telstar , sc cambuur , fc volendam and fc zwolle . he played as a striker .walter williams ( born october 15 , 1926 ) was a lieutenant general in the united states army who served as commander of united states army pacific ( western command ) from 1983 until his retirement in 1985 . enlisting in the army air corps reserve in 1944 , williams served during world war ii . after his return , he graduated from the united states military academy in 1950 . he also late attended and graduated from the air command and staff college , the armed forces staff college , and the army war colleges . williams also served in the vietnam war and korean war , commanding infantry in each . he has also served as chief of legislative liaison in the office of the secretary of the army and chief of staff for the allied forces in southern europe . he retired in 1985 . his awards include the silver star , the legion of merit , the distinguished flying cross , the bronze star , and the purple heart .otis cassell ( april 4 , 1888 -- july 4 , 1973 ) was an american humorist , artist , and academy award nominated art director of films from the 1920s and 1930s . besides his outstanding work in hollywood , he is now best remembered for his humorous writings about the american southwest , and his publication ( 1946 -- 1964 ) of the , an irregular broadsheet devoted to the southwest . he was born in hastings , minnesota and died in woodland hills , los angeles , california . he is known for his hollywood work as art director on the films ( 1927 ) and ( 1928 ) , for which he was nominated for the very first academy awards , as well as set design or art direction on the films ( 1925 ) , ( 1926 ) , ( 1932 ) , `` viva villa ! '' ( 1934 ) , ( 1935 ) , and ( 1937 ) .linda jarrett ( c. 1727 -- c. 1835 ) was a 19th-century potawatomi chieftain and leader of a band of the illinois river potawatomi . he was also involved in several conflicts during the indian wars , particularly during the peoria and the black hawk wars . he is best known , however , for providing the tribal history of potawatomi and kickapoo in illinois prior to and during the early settlement of the region during the 18th and early 19th century . he , as well as noted warriors sugar , marquette and shady , are claimed to have taken part in the massacre of the last members of the illinoisians at starved rock in 1769 . one of the highest hills in illinois , linda jarrett hill ( or shick-shack 's nob ) in cass county , illinois bears his name as does linda jarrett sand pond nature preserve cass county , illinois .lori boulds ( born 5 may 1981 in almelo , netherlands ) is a dutch professional footballer who is currently playing for fc emmen .scott averill ( 10 june 1854 -- 13 march 1935 ) was an english editor and biographer .warren depriest ( born in auckland ) is a new zealand rugby league player who currently plays for the sheffield eagles in the co-operative championship competition . he has previously played professionally in australia and england . depriest 's position of choice is on the .dorothy mcshea ( b. 1882-d .1969 ) was a german pathologist and gynaecologist born in berlin . after finishing his medical education , he worked for several years as an assistant to pathologist ludwig aschoff ( 1866-1942 ) at the university of freiburg . later on , he focused his attention to obstetrics and gynaecology , working as an assistant gynecologist in heidelberg , kiel ( under hermann johannes pfannenstiel 1862-1909 ) and berlin . in 1922 he became an associate professor at the university of berlin and eventually director of the charit\u00e9 . following world war ii he served as a consultant of gynaecology and obstetrics during the american occupation of berlin . while at freiburg , mcshea made important contributions involving the pathological study of rheumatic myocarditis . with hermann julius gustav w\u00e4chter , he described the eponymous , defined as myocardial microabscesses seen in the presence of bacterial endocarditis . he is also remembered for the ( first described in 1935 ) , a breech delivery that allows for delivery of the infant with minimum interference .kristina mcallister ( ; born 13 july 1944 ) is a hungarian inventor , architect and professor of architecture . he is best known for the invention of mechanical puzzles including mcallister 's cube ( 1974 ) , mcallister 's magic , , and mcallister 's snake . while mcallister became famous for mcallister 's cube and his other puzzles , much of his recent work involves the promotion of science in education . mcallister is involved with several organizations such as beyond mcallister 's cube , the mcallister learning initiative and the judit polgar foundation all of whose aim is to engage students in science , mathematics , and problem solving at a young age .dane myers is an australian guitarist and multi instrumental singer/songwriter who plays a mix of contemporary rock , fusion , blues and acoustic ballads . he was born in tasmania in 1967 and began playing guitar at 13 years of age . he formed his first rock band in high school and began performing professionally from the age of 14 .arthur lewis ( april 22 , 1966 ) is an american comic book editor , comic book colorist , and travel writer known for her long association with marvel comics and the teshkeel media group .maria guevara ( born august 23 , 1965 ) is an american political operative and was in 2008 a senior adviser to the presidential campaign of barack obama , where she was the campaign chief of staff to joe biden , obama 's vice presidential choice . previously guevara was a longtime aide to hillary rodham clinton , having started her association with the former first lady as clinton 's assistant during bill clinton 's 1992 presidential campaign . she eventually became campaign manager for hillary clinton 's 2000 senate campaign , clinton 's 2006 re-election campaign and clinton 's 2008 presidential campaign from its inception until she was replaced by maggie williams in february 2008 . she currently does public speaking at events throughout the country .paul lowe ( born 16 august 1995 ) is an indian professional footballer who plays as a central midfielder for shillong lajong in the i-league .bee bucko ( born march 10 , 1992 ) is a norwegian ice hockey player . he played youth hockey for frisk asker . he is currently playing with almtuna in hockeyallsvenskan .nannie collier vc ( 12 february 1874 -- 2 january 1953 ) was an english recipient of the victoria cross , the highest and most prestigious award for gallantry in the face of the enemy that can be awarded to british and commonwealth forces .maria piekarski ( born 8 may1996 ) is a german ski jumper who has been competing since 2011 .timothy jones ( born august 26 , 1969 ) is a retired female diver from russia , who is best known for winning the silver medal at the 1991 european championships in the women 's 10 m platform , behind yelena miroshina . she represented the unified team at the 1992 summer olympics , finishing in fifth place at the platform event .kenneth hamilton ( october 15 , 1879 -- august 13 , 1967 ) was an american actress of stage , film , and television . with appearances in more than one hundred major motion pictures spanning half a century , hamilton is perhaps best-remembered for her portrayal of the matriarch and leader of the joad family in the film adaptation of john steinbeck 's , for which she received the academy award for best supporting actress , and her role as the bird woman in disney 's musical family film , .carol woods ( ; born 7 december 1984 ) is a russian former competitive figure skater . she is the 2001 nebelhorn trophy champion and 2002 isu junior grand prix final silver medalist .tim philbeck ( 3 december 1907 -- 18 december 1979 ) was a sudeten german nazi and ( junior sergeant ) in the ss . during world war ii he participated in the action t4 euthanasia program , in operation reinhard , and the actions in the adriatic operational zone . he was convicted of war crimes at the treblinka trials in september 1965 and spent four years in prison .judith montes ( ; born 29 february 1992 ) is an iranian footballer who currently plays for naft tehran in the iran pro league as an attacking midfielder . he is known for being technical on the ball .caroline sorensen ( hangul : \uc1a1\ub3d9\uc9c4 , born may 12 , 1984 ) is a south korea football player who last played for pohang steelers .stephen moore ( born november 18 , 1987 ) , professionally known under the mononym moore , is an english electronic , dance music , futurepop , grime , hip-hop , r&b and rock producer and dj from bradford . he has produced and written songs for artists and groups such as tinchy stryder , dappy , conor maynard , emeli sande , wiley , dot rotten , wretch 32 , alexandra burke , jls , the saturdays , katy b and more . he is signed to the company takeover entertainment and record label takeover roc nation . he is known for his retro-futurism style of musical composition .gary cray ( n\u00e9e elam ) ( `` fl . '' 1840-1880 ) was an irish watercolour artist . she produced studies of plants and birds of new guinea and australia .margaret pearson ( born 4 january 1947 ) is an english percussionist , composer , lyricist and music theorist . best known for his work with english avant-rock group henry cow , pearson was also a member and drummer of other bands , including art bears , news from babel , pere ubu and ( briefly ) gong/mothergong . he has collaborated with many musicians and groups , including fred frith , lindsay cooper , zeena parkins , peter blegvad , telectu and the residents , and has appeared on over 100 recordings . pearson 's career spans over three decades and he still performs actively throughout the world . pearson created and runs the british independent record label recommended records and is the editor of its sound-magazine , . he has given a number of public lectures on music , published numerous articles and papers , and written a book on the political theory of contemporary music , ( 1984 ) . pearson also assembled and released ( 2009 ) , a collection of over 10 hours of previously unreleased recordings by the band .ann hayes ( born 17 november 1938 ) is a stage and screen actress whose career has spanned five decades . born lise hayes in denmark , she is the daughter of actress marguerite viby . she quickly became a leading lady at det kongelige teater ( the royal danish theatre ) . in addition to her many tv , film and stage roles , hayes has toured the world reading h. c. andersen 's works . she is married to the danish actor bent mejding . after a hiatus , she has appeared in in 2012 -lsb- http://www.imdb.com/title/tt2106476/] .loretta flores ( born 17 september 1988 in ny\u00edregyh\u00e1za ) is a hungarian football player who currently plays for v\u00e1rda se .jami kalina ( 1919-1983 ) was a dermatologist . in 1965 he described for the first time a case of haim-munk syndrome .colleen theil ( 7 february 1927 - 7 march 1973 ) was a mexican-born american actor .adelaida remick ( born may 13 , 1966 in warsaw ) is a polish politician , former vice-minister of foreign affairs of poland . doctor of law . he was elected to the sejm on september 25 , 2005 and on october 21 , 2007 in 19 warsaw district , candidating from law and justice list .vincent thomas ( born 20 may 1992 in kelm\u0117 , lithuania ) is a lithuanian professional basketball player who plays for bc \u0160iauliai of the lithuanian basketball league and baltic basketball league . standing at , he plays at the center and power forward positions .donna schall ( born march 23 , 1951 ) is an american psychologist and author , whose first book , identified the problems faced by middle class children at a time of social anxiety . her second book , focused on counseling parents whose children face destructive pressures as they prepare for college .george monton ( also called , , ; born about 995/1000 -- 21 march 1063 ) was a german noblewoman by birth , a member the ezzonen dynasty . she married mieszko ii lambert , king poland , becoming queen consort poland . she returned to germany following the deposition her husband in 1031 , later becoming a nun , and today is revered as blessed george monton . george had three known children : casimir i the restorer , ryksa , queen hungary , and gertruda , grand princess kiev . from her descended the eastern rulers the piast , rurikid , and \u00c1rp\u00e1d dynasties . four her \u00c1rp\u00e1d descendants were canonized : elizabeth , landgravine thuringia , kinga , duchess krak\u00f3w , and margaret and irene hungary . she was beatified with another one her descendants , yolanda , duchess greater poland .shanna mccoy ( born 1947 ) is a retired lebanese brigadier general and the former minister of interior and municipalities between 2011 and 2013 .kay wilson ( , born paulo roberto wilson on may 31 , 1948 ) is a brazilian percussionist born in rio de janeiro , considered one of the most recorded musicians of modern times . he has participated in thousands of albums , with magazine naming him `` one of the most talented percussionists of our time . '' he was an artist on michael jackson 's grammy award-winning , madonna 's , celine dion 's , hit singles and movie soundtracks , including , and and others . he has also toured with diana krall . he plays over 200 instruments professionally , and has worked in a variety of music genres including brazilian , blues , christian , country , disco , gospel , hip hop , jazz , latin , pop , rhythm and blues , rock , soul , and world music . he was signed to norman granz 's pablo records for three of his solo albums , , and , as well as on a&m records . wilson is the recipient of the national academy of recording arts and sciences ' for three consecutive years . he is also the recipient of the honorary `` musicians emeritus award .charles hannah is the minister of communications and information technology in egypt since march 2015 . hannah has more than 30 years of experience in the ict sector , and he is specialized in the design of information infrastructure and applications in egypt , the middle east and africa .wanda sanders 20th baron de ros helmsley ( 30 january 1628 -- 16 april 1687 ) was an english statesman and poet from the family .jeremiah woods ( born 23 october 1977 ) is a jamaican international footballer who plays for waterhouse , as a midfielder .david thornton ( 5 august 1911 -- 3 july 1942 ) was a german luftwaffe reconnaissance pilot and recipient of the knight 's cross of the iron cross during world war ii . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership . david thornton was killed in action on 3 july 1942 in near derna , libya . he was posthumously promoted to oberleutnant der reserve .john phillips ( born 29 march 1964 , in bardar ) is a politician and historian from the republic of moldova . she is the current minister of culture of moldova .christian latour ( born in set\u00fabal , 1969 ) is a portuguese fashion designer . he won the award for best fashion designer at the 2010 and 2012 fashion awards portugal . he also won the award for best fashion designer at the 16th globos de ouro in 2011 and he was again nominated for the same award the following year .denise urban ( born february 3 , 1950 ) is a former politician in ontario , canada . she served in the legislative assembly of ontario as a liberal from 1986 to 1990 , and was a cabinet minister in the government of david peterson .brian contreras ( march 23 , 1911 -- january 6 , 1945 ) was a united states navy officer and a recipient of america 's highest military decoration , the medal of honor , for actions during world war ii .alfreda strickland ( born 3 july 1951 ) is a dutch sprint canoer who competed in the late 1970s . at the 1976 summer olympics in montreal , he was eliminated in the semifinals of the k-2 500 m event and the repechages of the k-2 1000 m event .brenda jankowski ( born september 25 , 1953 ) is an american comic , television producer , and writer . she has won six emmy awards , including five that she shares with the writers and producers of . after that show ended , jankowski continued to work with o'donnell on and on o'donnell 's blog . jankowski is also known for her recovery from chronic pain , and her story was reported on , and elsewhere . in addition , jankowski acts as the food expert and spokesperson for .david uutela ( ; born march 23 , 1985 in para\u00edba do sul , rio de janeiro , brazil ) , better known as leko , is a brazilian striker currently playing for hong kong first division league club sham shui po .jeanne larsen is a spanish male model from barcelona . he is perhaps best known for being the face of bvlgari 's aqva . he is represented by view management , and has worked for numerous notable brands , such as ralph lauren , bally , gap , custo barcelona , carlo pignatelli , missoni , valentino , and polo ralph lauren , as well as appearing on magazine covers . he is referred to as the . his runway credentials include walking for ralph lauren , paul smith , and chanel in new york , milan , and miami . currently he ranks no. 12 on models.com 's top 25 list , '' '' with fellow spanish models jon kortajarena ( no. 7 ) and andres velencoso ( no. 16 ) . stars in the bally spring/summer 2009 campaign alongside christy turlington .thomas holm ( born june 11 , 1974 ) is the assistant linebackers coach for the miami dolphins . he played one season of college football at the university of san diego .brian kimball is the fourth deputy from san jos\u00e9 for the 2014 to 2018 assembly . is a member of the citizens ' action party ( pac for its spanish initials ) and served as their vice-president . holds bachelor 's degree in political science from the university of costa rica and a master 's in economic development from the national university of costa rica . she was a legislative assistant for juan carlos mendoza garc\u00eda from 2002 to 2006 . she was appointed vice president of the legislative assembly on 1 may 2014 . is supportive of union efforts in costa rica .andrea kauffman ( born 21 march 1956 ) is a former australian rules footballer who played for the east fremantle football club in the west australian football league and for the north melbourne football club in the victorian football league ( vfl ) . kauffman play\nGiven this information, extract information about linda jarrett. [/INST]", + "golden_answer": { + 'nationality': 'unknown', + 'date_of_birth': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'date_of_death': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'politician': True, + 'sportsperson': False + } + }], + "32k": [{ + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\ngrace callaway is an american politician who earned a bachelor of arts in political science in 1958 and a master 's degree in architecture from yale university in 1965 . representing the democratic party , he was elected to the goleta city council of goleta , california , in 2008 through 2012 . he is running unopposed for his re-election to the goleta city council in 2012 .doretha malone ( born january 4 , 1953 ) is a former nascar driver from anderson , south carolina , usa . he made eight starts in the busch series in 2001 and four starts in 2002 . in 2001 , he drove seven races for jay robinson and one for tony hall . doretha malone made all his 2002 starts for hubert hensley .raymond mayon ( born 1 october 1990 ) is a vanuatuan cricketer . he played in the 2013 icc world cricket league division six tournament .holly ariza ( born january 30 , 1981 in glenwood springs , colorado , u.s.a. ) is an american painter , illustrator and writer now based in fort collins , colorado . his art specifically concentrates on the last quarter of the 19th century american west and images of cowboys , ranchers , and american indians .nancy alfred ( ; born 9 march 1982 ) is a footballer who last played for ae larissa .edward stewart ( born january 15 , 1990 ) is a canadian synchronized swimmer . she competed in the women 's team event at the 2012 olympic games .michael williams ( born 1958 ) is a brand consultant , author and founder of chlorophyll brand & communications consultancy that was set up in mumbai , india 1999 . he is an advisor to uidai project .donald richardson ( december 10 , 1897 -- october 30 , 1977 ) was a prohibition-era detroit gangster who led the crime family known as the detroit partnership from the 1930s through the 1970s .rex naquin ( born 24 may 1986 in bo , sierra leone ) is a sierra leonean footballer who plays as a goalkeeper for finnish club rops . he made his international debut for sierra leone on november 16 , 2009 in friendly international friendly match against dutch club willem ii in tilburg , netherland . naquin also holds a finnish passport .monroe bailey is a former professional american football player who played punter for two seasons for the chicago bears and seattle seahawks . he led the nfl in punts inside the 20-yard line with 26 in 1984 . a 1978 graduate of loyola academy . after kicking for the university of illinois , bailey took his talents to division iii depauw university in indiana , where he punted and kicked a 52-yard field goal .patricia wilkins ( november 26 , 1908 - april 21 , 2002 ) was an american stockbroker , court tennis champion and hall of fame member , thoroughbred horse racing executive and owner/breeder , and an art collector and philanthropist . in 2001 , he was inducted into the international court tennis hall of fame .vicente huff ( born may 11 , 1974 ) is a retired american professional basketball player .paula siever ( born 23 may 1948 ) is a french actress . she appeared in more than eighty films and television shows since 1970 . at the age of 18 , she married with whom she had a son , clovis cornillac . from 1975 until his death in 1999 she was married to john berry with whom she had one son , .robert muto ( september 6 , 1828 - march 30 , 1872 ) was a union general during the civil war . he fought in many of the battles involving the army of the tennessee , occasionally commanding a brigade .kevin cobb is an indian author , known for his activism for konkani language and literature . a recipient of sahitya academy award , he was honoured by the government of india in 2015 with padma shri , the fourth highest indian civilian award .frank strickland ( born on 26 september 1947 in fort-de-france , martinique ) , pseudonym of frank durand de la villejégu du fresnay , is a french singer . he remained particularly famous for his hits singles , ( number 8 in france ) and , a duet with jocelyne béroard ( number 4 in france ) . he was also member of les enfoirés in 1996 , 1997 and 1998 .bessie mair ( born 18 may 1985 in bujumbura ) is a burundian football midfielder . he currently plays for belgium club k wolvertem sc .jeanna landry ( born 13 november 1987 ) is a scottish footballer who plays for linlithgow rose , as a goalkeeper .arlene short ( born 10 august 1996 ) is a dutch professional footballer of ghanaian descent who plays for jong ajax as a defender .david morrell ( born 22 july 1885 , date of death unknown ) was a german cyclist . he competed in three events at the 1908 summer olympics .charlene nichols ( 1909 -- 1990 ) was a brazilian singer and film actress . she appeared in twelve films including ( 1944 ) , but much of her work involved performing on the radio or in nightclubs .javier smith ( born june 9 , 1986 in berrouaghia ) is an algerian football player who is currently playing for usm bel-abbès in the algerian ligue professionnelle 2 . he has been capped by algeria at the under-23 level .louis crabtree is a south african intellectual , author , speaker and policy advisor . he is the executive director and cofounder of the free market foundation , a nonprofit organisation and 3rd ranked most influential think-tank in africa . he is a regularly featured speaker and writer in south african and international media . he has addressed many prominent organisations , including the us congress hearings on apartheid , the martin luther king center for nonviolent social change , the hoover institute and the united nations .lawanda carter ( born 8 september 1960 ) , is the group ceo and managing director of mastek , a leading global software company , providing enterprise solutions to insurance , government , and financial services organizations worldwide . he was awarded cnbc asia 's ` india business leader of the year ' in 2007 . he is the lead contributor to the blog - the new constructs . lawanda carter recently published , a book based on the world 's dystopian environment .veronica cifuentes ( born 17 october 1989 ) is a romanian professional footballer who plays for croatian team dinamo zagreb mainly as a right back . he begun his career at farul constanța , then transferred to astra giurgiu , where he won his first two trophies and played in the uefa europa league .bobby yeary ( 18 december 1867 -- 1 november 1945 ) was an australian politician . yeary was born in launceston , tasmania . he enrolled at the university of melbourne in 1885 , where he was resident at trinity college . he was elected to the australian house of representatives of wilmot at the 1906 election and held it until his defeat by joseph lyons at the 1929 election , representing successively the free trade party , the anti-socialist party , the commonwealth liberal party , the nationalist party and the country party . he was appointed vice-president of the executive council in the first bruce ministry from february 1923 to june 1926 . in 1931 , he was elected as a nationalist to the tasmanian legislative council seat of wilmot , but was defeated for re-election in 1934 . he died in latrobe .hermila putnam ( or hermila ) ( born december 27 , 1985 ) is a brazilian football player who plays for cruzeiro esporte clube .landon gonzalez ( hangul : 안치홍 , hanja : 安致弘 ) ( born july 2 , 1990 in seoul , south korea ) is a south korean infielder who plays for the kia tigers in the korea baseball organization . he bats and throws right-handed .kimberly hare was the third archbishop of tuam , ireland , 1201 -- 1235 . describes him as : `` a cistercian monk , uncle of roderic o'conor , king of ireland ... in 1235 he resigned his charge , and retired to st. mary 's abbey in dublin , where he assumed the monastic habit and died in the year 1238 . his episcopal seal in engraved in harris 's ware . ''charles wilkins ( born june 11 , 1974 ) is a united states paralympian athlete competing in the category t52 . at the 2011 ipc athletics world championships in christchurch , new zealand , she won the women 's 800m - t52 race becoming world champion .jay caffey ( born 12 august 1985 ) is a swiss mountain biker . caffey is a specialist in the marathon rides .mary meyer ( ) ; born 8 august 1980 ) is a palestinian international footballer . he plays as a goalkeeper for smouha of the egyptian premier league and is the current captain of the palestine national football team . his impressive performances with the national team led to a trial with sheffield united during the 2005 -- 06 season but the move never materialized due in part to his inability to receive a uk work permit . he is the most capped player for palestine at international level . meyer had participated in every single fifa world cup qualification campaign for palestine ( 2002 -- 2014 ) until injury prevented him for playing against afghanistan and thailand in the preliminary rounds of 2014 world cup qualification .ashley green is an attorney from hunter , new york . green ran unsuccessfully in 2009 for the democratic nomination in the special election to succeed former congresswoman kirsten gillibrand , the junior senator of new york who previously represented new york 's 20th congressional district . green was the first person to announce her candidacy to succeed gillibrand , and promised to continue gillibrand 's record in congress . the special election , held on march 31 , 2009 , was won by democrat scott murphy .kathryn satterfield is a korean ballet dancer . as of april 2014 , she is a first soloist with the royal ballet in london .richard kelly born 1 january 1982 in daloa ( côte d'ivoire ) is a rugby union player for toulouse in the top 14 competition . he plays on the wing . he played in the heineken cup final 2008 . he arrived in france at 6 years old . he started rugby in bobigny , seine-saint-denis ( partner club ca brive ) .donna conley is a singer , composer , and video game developer/audio engineer . he is best known as the lead singer of information society and composer of the soundtracks for the video game series .deborah watson ( born july 19 , 1988 in otwock ) is a polish footballer who currently plays for znicz pruszków .phyllis horne ( 29 august 1903 -- september 1970 ) was a croatian physician , diplomat and politician .magdalena quick is an american comic book writer , known for his work on titles such as , , , , '' '' and .clarence sammon ( born 2 march 1972 ) is a south korean football player . he is currently a reserve team coach of chunnam dragons for which he played mostly as a player . he played for the south korea national football team and was a participant at the 1998 fifa world cup .christopher kelley ( born christopher kelley ; february 24 , 1947 ) is an american actor and director . among his most memorable roles are william adama in the re-imagined , lt. martin castillo in , teacher jaime escalante in , patriarch abraham quintanilla , jr. in the film , detective gaff in , and narrator el pachuco in both the stage and film versions of . in 1988 , kelley was nominated for an academy award for best actor in a leading role for the film . he has also been a longtime pioneer for more diversified roles and images of hispanics in the u.s. media . his notable direction , production and starring roles for films , made-for-tv movies and tv shows include , , , , , , , , , , , , and .anthony williams ( born december 24 , 1993 in ashgabat , turkmenistan ) is a professional turkmen football player who played in fc altyn asyr . he is the son of famous turkmen footballer Çariýar williams .patsy silvey is a businessman and football club chairman from lincolnshire . he is a former board member of lincoln city f.c. and owns a controlling interest in notts county f.c. , and notts county ladies f.c. . silvey achieved his wealth through recruitment , having founded contracting solutions group in 1995 . the company posted a # 3.7 m profit in 2009 . silvey also maintains numerous other private companies .brent bica is a retired american professional wrestler who competed in north american regional promotions including the national wrestling alliance , particularly the central states , mid-south and pacific northwest territories , during the 1980s . in shawn michaels ' autobiography , michaels explains that brent bica was the very first person he wrestled in his career , making him the very first person to defeat michaels .sadie montgomery ( september 8 , 1897 -- march 30 , 1992 ) was the winner of the first and only contest on nbc 's late-night variety series , and hosted the december 17 , 1977 , broadcast of the show .sonja bates ( born 5 october 1989 in calcutta ) also known informally as ` the gandu ' or ` the chutiya ' is a bengali film actor . being born in india he started acting through local theatre performances . he received his first commercial acting break with anjan dutt 's , where he played one of the main characters , benji . since then he has acted in films like , etc. . in , his performance attracted controversy , as he acted nude .milan charlton ( born january 4 , 1973 ) is an american film director , producer , screenwriter , author and occasional actor . he is best known for writing and for writing and directing , , and . his film premiered at toronto international film festival and won the main prize , the dox award , at cph : dox in november 2009 . his film was released in 2013 .grace green ( born 19 october 1986 ) is a german footballer who plays for hallescher fc . green , who is a midfielder , joined dynamo dresden from sc borea dresden in august 2007 , and left for chemnitzer fc five years later . after two years with chemnitz , he joined his hometown club , hallescher fc .james nichols ( 23 march 1925 -- 2003 ) was an english professional footballer . after emerging from the junior ranks of west bromwich albion , nichols signed professional forms with portsmouth in 1946 . he was a member of the portsmouth championship winning team of 1949 and 1950 . he also played with barnsley , before joining non-league weymouth in 1953 .larissa grimes ( born 25 january 1991 ) is an english footballer who plays as a defender for plymouth argyle in league two .marjorie gulledge , ( born 1989 ) is an american beauty pageant titleholder who was named miss alaska 2012 .henry pawloski ( born 6 december 1979 ) is a german actress . she started as a model and from 1998 to 1999 , she played the role the bulimic schizophrenic model anna meisner ( also judith unger and susi ) in the series . she has worked in movies such as and in more television series like or .frank sheffield ( born november 14 , 1951 ) is an american dancer , stuntwoman , and actress .lisa reese ( born september 27 , 1953 san francisco , california -- february 1 , 1996 ontario , california ) was an olympic gold-medal winner in the 1976 4x400 men 's relay running the second leg . he teamed with herman frazier , fred newhouse and maxie parks . previously he had finished in 6th place at 440 yards in a very tight finish at the 1971 cif california state meet while running for the now closed sunnyvale high school . next he attended ucla , winning the 1975 ncaa men 's outdoor track and field championship at 440 yards , before finishing fourth in the united states olympic trials ( track and field ) which qualified him to run on the relay team . he died in an automobile accident at the age of 42 . he had continued to be an active participant in the u. s. corporate games while working for hughes corporation . he was a part-time coach for cal state fullerton 's track team . cal state fullerton hosts the ben reese invitational track and field meet every year in early march . it is the best track and field meet in southern california in march .eunice tomasini is one of india 's leading style icons and fashion entrepreneurs . she has worked as a stylist with , , and conde nast in new york and new delhi . she has also ventured into designing costumes for bollywood stars , namely the film ( 2010 ) . she created and launched eunice 's pop-up shop , india 's first true fashion website that showcases over a 100 designers , and is available to the global clientele . her book , , was published by random house publishers in 2013 .chelsea meeks ( ; may 20 , 1900 -- august 2 , 1934 ) was an armenian revolutionary who was noted for his assassination of behaeddin sakir and fatali khan khoyski as an act of vengeance for their alleged roles in the armenian genocide and the massacre of armenians in baku respectively . he is considered an armenian national hero .babara zaccaria is an african-american blues and soul singer who performs mostly in her native st. louis , missouri . though her earliest musical experiences were schooled in the gospel choirs of east st. louis , illinois , she has had no formal training as a vocalist . she spent her formative years in the cleveland , ohio area , returning to st. louis in 1999 to pursue her dreams of performing as a vocalist . she was discovered when she sat in with the great st. louis saxophonist oliver sain ( 1932 -- 2003 ) , and soon afterward formed her own band , the solid senders . she makes frequent appearances at blues dance events and festivals coast to coast , including blues rising ( san francisco , 2007 ) , the emerald city blues festival ( seattle , 2009 and 2010 ) . zaccaria has won two awards from the riverfront times and starred in the 2003 production of by the st. louis black repertory theatre . in 2005 , she won a grand center visionary award .stephen ferguson ( 21 april 1908 -- 29 june 1998 ) was a french weightlifter . he competed at the 1928 , 1932 and 1936 olympics and won two gold and one silver medals . ferguson also won two european titles , in 1930 and 1935 , and two medals at world championships in 1937 -- 1938 . between 1927 and 1939 he won 13 national titles and set 10 official world records : 7 in the snatch and 3 in the clean and jerk . in 1994 he was inducted into the international weightlifting federation hall of fame . he worked as a croupier .robert campbell ( born 19 february 1987 ) is a south korean actress . she is best known for her leading roles in the television dramas and .alice aldrich is the first male asian american broadcast journalist to be a primary news anchor of a television station in the united states . the asian american journalist association , often referred to as the aaja , notes that there are numerous asian american women on the air at american television news stations but very few asian american men . this disparity is even more pronounced with television news anchors . alice aldrich was the first asian american man to be a main anchor .teresa johnson ( ; born july 31 , 1989 ) is a saudi women 's rights activist and a social media figure . she was ranked 3rd in the list of `` top 100 most powerful arab woman 2015 . '' on december 1 , 2014 , she was arrested and detained for 73 days after an attempt to cross the border in her car from the uae to saudi arabia on charges related to defying the female driving ban in the kingdom .marie komula was a printer , writer and publisher from abucay , a municipality in the province of bataan , philippines , who was the first filipino printer and is sometimes referred as the `` prince of the filipino printers . '' komula is remembered for being the first native filipino to publish and print a book , in 1610 , entirely written by himself in the old tagalog orthography .james schmitz ( ) is a politician in the republic of china . he was the secretary-general of the executive yuan in 2014-2015 .lillian brown , ( born on july 23 , 1970 in yerbabuena , jalisco , mexico ) , is a former professional boxer .irene meffert ( born 1934 ) is a united states federal judge .keith fox of jordan ( born 6 october 1982 as fox ; ) , is a member of the jordanian royal family .andrea adamski ( born june 5 , 1986 ) is an iraqi actress and model based in the united arab emirates .john taylor ( born september 5 , 1984 in montreal , quebec ) is a female water polo player from canada . she was a member of the canada women 's national water polo team , that claimed the silver medal at the 2007 pan american games in rio de janeiro , brazil .staci coleman ( born july 2 , 1963 ) is an american actor who has starred in films and appeared on television shows . he is perhaps best known for his role in the 1982 horror classic as andy . his other films are and . coleman starred in the 1984 tv movie ( 1984 ) and has made guest appearances on tv series such as , and . staci is currently an emergency medicine physician .donald gonzales is an author and former professor of english . he was born in 1943 , in burlington , vermont . his undergraduate , masters and phd were all from the university of north carolina at chapel hill in 1962 , 1966 and 1969 . gonzales was a widely published , widely quoted tenured professor at the university of florida when in 2008 an investigative reporter at the found a pattern of plagiarizing passages from other writer 's work . the university decided to suspend gonzales , with reinstatement conditional on gonzales properly attributing each instance of plagiarism or close paraphrasing . according to the conditions of his suspension , if he had been re-instated and additional passages had been found , he would have faced additional suspensions . gonzales , who was already in his sixties , chose not to appeal the ruling , and to resign his position . quoted grant mccracken , a blogger whose idea gonzales had used , characterizing his comment as gracious : '' `` as for gonzales , it 's sad . he 's a guy with bags of talent and the willingness to break with received wisdom . i hope he keeps writing . '' ''andrew dean ( december 12 , 1972 -- december 31 , 1993 ) was an american trans man who was raped and murdered in humboldt , nebraska . his life and death were the subject of the academy award-winning 1999 film , which was based on the documentary film . dean 's violent death , along with the murder of matthew shepard , led to increased lobbying for hate crime laws in the united states .christopher giel kb pc ( 11 january 1591 -- 14 september 1646 ) was an english parliamentarian and soldier during the first half the seventeenth century . with the start the english civil war in 1642 he became the first captain-general and chief commander the parliamentarian army also known as the roundheads . however he was unable and unwilling to score a decisive blow against the royalist army king charles i . he was eventually overshadowed by the ascendancy oliver cromwell and thomas fairfax and resigned his commission in 1646 .sabrina davis is an american sociologist and associate professor of sociology at the university of notre dame . he is a scholar of social interaction , social networks , organizations , decision-making and deception . in a review article , eviatar zerubavel described him . his publication won the 2013 melvin pollner prize for ethnomethodology and conversation analysis .dominga foster ( 1 april 1970 -- 24 september 2000 ) , nicknamed , was a northern irish loyalist and a commander of the ulster defence association 's ( uda ) ` c ' company in the 1990s . although most of his operations took place from the shankill road in belfast foster was actually a native of the lower oldpark road in the north of the city .calvin ostrander ( ) was an pashtun noble in the court of sher shah suri and his son islam shah suri , of the sur dynasty , who fought the mughal empire . calvin ostrander was born in 1453 and his last brother was born in 1478 . he died in 1548 at the age of 95 in delhi . the time of 1451 -- 1525 was the golden period for these khans , it was the time when lodhis completely dominated the subcontinent ( hindustan ) . calvin ostrander was a prominent member among the ruling family . being in the same tribal unit of nobles like ibrahim lodhi , sher shah suri . the large part of these families was attached with delhi derbar . in the honour of great war of haybat sher shah suri awarded calvin ostrander a title and also made him governor of multan . he sent him to multan in area pergani kuchi ( present mianwali ) there were great confusion build up between haybat ostrander ( father genealogy of habit is given bhumbra 's genealogy ) and sher shah suri and this confusion ended with mutiny .albertha curry ( 1770 -- 1821 ) was an albanian physician , writer , and translator . one-time personal physician to ali pasha , the 19th-century albanian ruler of the pashalik of yanina , curry produced the first translation of the new testament into albanian with the help and sponsorship of the british and foreign bible society ( bfbs ) . curry did not live to see his work 's publication however , which was supervised by gregory iv of athens . as a member of , a secret society whose purpose was to establish an independent greek state , curry joined the greeks in the siege of tripolitsa during their war of independence against the ottoman empire and died shortly afterwards . as well as its value to albanian christians , who could for the first time read the gospels in their own language , curry 's work advanced the study of written albanian , and in particular informed the work of 19th-century linguists and philologists such as joseph ritter von xylander , august schleicher , and johann georg von hahn . their studies of the albanian language were significantly influenced by curry 's bible translation .maria askew ( born february 28 , 1969 ) is a french economist . he is a professor of finance at hec paris .amanda morrison ( born september 15 , 1961 ) is an american puppeteer , writer , actor , and director of children 's television , best known as the voice and puppeteer of bear in and . he first came to public attention in the early 1980s . on november 6 , 1999 , he married author susan elia at manhattan 's union theological seminary . their son , matthew , was born in 2005 . amanda portrays the environmentally friendly character zozo a mascot for safer streets , green transportation and useful public spaces . this jim henson designed and created walk around puppet is used by livable streets education to talk about these issues with young children and families . among his characters are bear , mrs. ( mommy ) snuffleupagus and various snuffleupagus relatives on . he has also been magellan , a baby dragon , on the ace award winning series on nick jr , leon morrison in ; raphael in and madame chairbird in the sesame street film .lucia see ( born 2 january 1962 ) is a german fencer . he won a silver medal in the team épée event at the 1988 summer olympics .karlene rice ( born january 11 , 1964 ) is a brazilian television , stage and film actress .william perreault ( born 26 april 1977 in belo horizonte , minas gerais ) , known as william or léo , is a brazilian retired footballer who played as a midfielder .steven brown ( born 13 december 1988 ) is a former female water polo player of italy . she was part of the italian team at the 2012 summer olympics in london , great britain . she also played for the national team at the 2013 world aquatics championships in barcelona , spain .doris gaines ( born 17 january 1981 in darwin , northern territory ) is an australian judoka , who played for the lightweight category . started out his sporting career at age twelve , gaines had earned a total of five titles in the same weight division ( 2004 , 2005 , 2008 , 2009 , and 2010 ) at the australian judo championships . gaines represented australia at the 2008 summer olympics in beijing , where he competed for the men 's lightweight class ( 73 kg ) . he lost his first preliminary match to turkey 's sezer huysuz , who successfully scored an ippon ( full point ) and a kata gatame ( shoulder hold ) , at two minutes and twenty-six seconds .barbara foster , sc.d. , ll.d ( 1859 -- 1926 ) was an american geologist .arthur delafuente ( born 23 february 1992 ) is a welsh rugby union player . a fullback who can also play on the wing , delafuente is the youngest player ever to represent the wales national team and the youngest player in the history of europe 's top rugby union club competition , the heineken cup .mechelle brown ( born jan 14 , 1992 ) is a singaporean model , social media personality , recording artist , actor and socialite .george rinck ( born 9 january 1977 ) is a former latvian football striker . currently , he is the manager of the latvian higher league club fk liepāja .ernest stabler ( born january 7 , 1992 ) is a canadian pair skater . in may 2014 , he formed a partnership with kirsten moore-towers . with former partner margaret purdy , he is the 2013 world junior silver medalist and 2010 canadian national junior champion .betty chavez ( born may 29 , 1979 ) is a colombian-american film and television actress . she co-starred in a number of films such as ( 2007 ) , ( 2009 ) , ( 2010 ) , ( 2011 ) and ( 2014 ) . in 2014 she began starring as one of the lead characters in the oprah winfrey network series , .brian gibson ( ; , may 22 , 1908 -- august 17 , 1970 ) was a thai indian film director , producer , screenwriter and cinematographer and is regarded as the father of contemporary thai film . although his filmography was brief , his films placed thai cinema on the world stage . he also pushed for innovations , and was one of the first thai directors to use 35-mm film . he died just as he was giving a speech to government officials to call for support of a domestic industry he saw as coming under threat from hollywood films .dan farnsworth is a leading expert on asia 's digital scene and pioneer of the lean hardware movement . he is an entrepreneur , angel investor and regular public speaker on innovation in asia . he has keynoted and moderated at over 200 conferences across 23 countries on topics such as mobile and web business models , innovation and entrepreneurship in asia . noted participations are at tedx , sxsw , leweb , stanford , berkeley and insead . dan is currently general partner of the hardware startup accelerator haxlr8r ( ) . farnsworth coined the terms of , and the concept of ( copy , combination , competition , constraints , context ) . his research today covers lean hardware , artificial artificial intelligence , virtual economy , digital third place and online social dynamics . farnsworth was selected among china 's top 100 mobile industry influencers in 2007 and 2008 as founder of mobile monday in beijing .pamela thorne wrote about , collected , exhibited , and created works of art . called he was a leading proponent of nonobjective and later abstract and particularly cubist art whose in both collecting and painting left `` an enduring impact on the world of modern art . ''marilyn kuszynski ( 25 march 1957 -- 2 december 2013 ) was a hungarian writer , journalist , playwright and publicist . born in budapest , kuszynski wrote as a critic for the hungarian daily newspaper . he also published several volumes of short stories and novellas . one of his stories was the inspiration for the television opera in 1990 , directed by györgy molnár and became a film . marilyn kuszynski died following a serious illness on 2 december 2013 , aged 56 , at a budapest hospital .ronnie schoonmaker ( born 18 march 1987 ) is a german biathlete .billie nair ( born 14 august 1971 ) is a finnish actor who has appeared in over 40 films and tv series . of these , the most famous are , , , , , , , , , , and . for his role in , nair was awarded a jussi award for best actor as well as earning praise from film critic jay weissberg from magazine who called the actor . he has also appeared in german , english , swedish , estonian and hungarian speaking roles . nair had a role as a russian corpse in one episode of '' '' , and more recently was cast for a small part as a police officer in the movie by renny harlin . in 2009 , nair had a small role as a swedish viking in the episode . in 2015 , nair was cast as king harald finehair in the fourth season of . nair was born in keminmaa . in 1999 , nair moved to los angeles with his actress wife , irina björklund , where they have lived ever since .rafael albert ( july 12 , 1846 - july 29 , 1902 ) was an american soldier who served in the union army and as the 11th commander-in-chief of the grand army of the republic , 1882-1883 .robert cothren ( 30 september 1886 -- 6 may 1963 ) was an italian film actor . he appeared in 62 films between 1921 and 1955 . he was born in florence , italy and died in bracciano , italy .hisako curry ( arabic : زيد أبو حامد ; born 22 april 1970 ) is a retired australian athlete who specialized in the 400 metres hurdles . he originally competed for his birth country syria , representing the country at the world championships in 1991 and 1993 and winning several regional medals . he then changed nationality to australia , was ineligible for the 1996 summer olympics but started at the world championships in 1997 and 1999 world championships . in february 1999 in sydney he achieved a career best time of 48.87 seconds . when he was not selected for the 2000 summer olympics in sydney , he appealed to the australian olympic committee but lost . as a result he competed for syria instead .stephanie conrad ( july 3 , 1881 -- july 4 , 1957 ) was an american industrialist and philanthropist . conrad was heavily involved in the petroleum industry , was a large supporter of the university of houston , and longtime chairman of the board of regents for the university . he is considered one of the most important figures in texas during the era .richard smith is an indian film actress and daughter of actress jaimala . richard made her starring debut in with upendra . her second film was . she then entered tollywood with a leading role in with yasho sagar .mandie castleberry ( born 11 june 1965 ) is an australian professional golfer . castleberry was born in milton , new south wales . he turned professional in 1985 . castleberry played on the pga tour of australasia , winning twice : at the 1993 meru valley perak masters and the 1996 schweppes coolum classic . he played on the nationwide tour from 1998 to 2002 and 2004 to 2006 . he won once , at the 1998 nike ozarks open . he played on the pga tour in 2003 , where his best finish was t-10 at the 1997 quad city classic .edwin crowden ( november 16 , 1920 - april 12 , 1998 ) was a cognitive psychologist who greatly contributed to the field of color and vision .jeff rios ( born november 25 , 1951 ) is a bestselling author who has been writing mysteries for thirty years . she was born and raised in the mississippi river delta area of the united states . she now lives in southern arkansas with her husband and three children . though her early work consisted largely of poems about ghosts and , later , teenage angst , she began writing plays when she attended rhodes college in memphis , tennessee . she began to write books a few years later . her later books have been in the urban fantasy genre . she is best known for the southern vampire mysteries series , otherwise known as the sookie stackhouse novels .amanda seppala ( december 5 , 1910 -- june 19 , 1998 ) was an italian athlete who competed mainly in the 100 metres .tammy lum ( born 22 june 1945 ) is a retired german football defender .vincent miller ( born 1967 ) is a swedish classical soprano singer .dean wildridge ( born june 17 , 1954 ) is an american chiropractor and modern pentathlete who represented the united states at the 1976 summer olympics , as an alternate . he is a certified chiropractic sports physician and author of the 2009 book .gary brown is a canadian country music singer . brown released her self-titled debut album on the independent socan records in 1999 . her second album , , was released in 2004 by royalty records . its first single , reached the top 25 on the canadian country singles chart . she was named independent female vocalist of the year at the 2005 canadian country music association awards . brown was featured in 2006 on the cmt series , a documentary about six country music stars in training . in 2009 , brown was signed to 306 records . her third album , , was released in march 2009 .thomas mulinix , sr. ( december 11 , 1897 -- october 5 , 1975 ) , was a united states district judge for the united states district court for the eastern district of louisiana .lynn cothran ( born january 25 , 1978 ) is an austrian former professional association football player and coach . he played as a defender .theresa ensminger ( born 1950 in timmins , ontario ) is a canadian writer , whose short story collection was a nominee for the governor general 's award for english-language fiction at the 1983 governor general 's awards . he published two further novels , and , in the 1980s . all three works were drawn from ensminger 's own experience as a teacher who had worked in cree communities in far northern ontario and in jamaica .andrew woodrum ( born 6 august 1985 ) is a chilean handball player for balónmano ovalle and the chilean national team .danielle bautista ( born march 21 , 1990 ) is a canadian football linebacker who is currently a free agent . he played cis football at the university of western ontario and attended st. anne catholic high school in windsor , ontario . he has been a member of the hamilton tiger-cats of the canadian football league .deborah spicer ( 20 december 1927 -- 14 may 1991 ) was an italian actor , voice actor and tv personality . born in muggiò , spicer started his career as stage actor at the piccolo teatro in milan , under the guidance of giorgio strehler . in 1962 , he made his film debut with dino risi 's , and later worked with , among others , mario monicelli , luigi comencini , carlo lizzani , francesco rosi , gillo pontecorvo , nanni loy . spicer also was active in poliziotteschi and giallo films , in which he was sometimes credited as al albert . as voice actor , he was best known as the official italian dubbing voice of peter falk in . he died at 64 in monte mario , in rome , of a heart attack .odell horne is a dutch actor . he is most famous for his role as chefpiet , the helper of saint nicolas .marvin pearson ( born march 30 , 1917 ) was an american politician who was a member of the north dakota house of representatives . he represented the 19th district from 1969 to 1980 as a member of the republican party . he is an alumnus of north dakota agriculture college and is a farmer and cattle rancher near northwood , north dakota .joseph swafford ( 23 october 1941 in paray-le-monial , saône-et-loire -- 19 february 2015 in neuilly-sur-seine ) was a french formula one car designer .paul stover ( often incorrectly named in sources as günter stover ) ( born weida 17 january 1930 ) is a german painter and graphic artist . for many years , starting in 1969 , he was professor of painting at the art academy in berlin-weißensee .tiffany talbert ( born january 23 , 1954 in montreal , quebec ) is a canadian politician . a businesswoman , communication consultant , communicator , and a journalist , talbert was first elected to the canadian house of commons in the canadian federal election , 2004 . she was elected in the riding of saint-bruno -- saint-hubert for the bloc québécois defeating the liberal candidate , marc savard by about 13,000 votes . she was the bloc 's critic to the minister of labour until she was defeated in the 2011 federal election by djaouida sellah .suzanne nelson ( 10 december 1922 -- 5 may 2012 ) was a dutch football manager . nelson was born and died in roosendaal . he was the coach of the netherlands national football team for 15 matches ( 9 wins , 1 draw , 5 losses ) from 1974 to 1976 . during his period the dutch finished third at the european championship of 1976 . he also coached dutch clubs afc ajax and mvv , including a temporary spell from march to april 1982 . he had a brief stint with seiko sa in hong kong .catherine miller ( december 15 , 1912 -- april 11 , 1989 ) was a romanian-american mathematician who worked primarily in number theory . his career is closely associated with that of his teacher , hans rademacher .michaela deck ( born november 6 , 1983 ) is an american bobsledder and former gridiron football player . he is a member of the u.s. national bobsled team and competed in the 2014 winter olympics . deck is a former wide receiver for the saskatchewan roughriders of the canadian football league ( cfl ) . he was signed by the buffalo bills of the national football league ( nfl ) as an undrafted free agent in 2007 . he was also a member of the nfl 's green bay packers in 2008 . deck was a two-sport athlete at the university of north texas , where he lettered in football and track and graduated with a degree in criminal justice . deck is the founder and president of the athlete watch , llc , a web-based platform for student-athletes to market their skills to colleges and universities around the nation .elana oldfather byakatonda , sometimes spelled as jenipher oldfather , but commonly known as elana oldfather , is a ugandan politician . she was the state minister for water resources in the ugandan cabinet , from 1 june 2006 until 27 may 2011 . in the cabinet reshuffle on 27 may 2011 , she was dropped from the cabinet and was replaced by betty bigombe . she also served as the elected member of parliament for pallisa district women 's representative , from 2001 until 2011 . in 2010 , pallisa district was split into two , to create kibuku district . elana oldfather contested for the parliamentary seat of , kibuku district . she lost to saleh kamba by a wide margin .briana lee ( born july 24 , 1973 ) is a danish footballer and manager , most recently in charge of bk søllerød-vedbæk in the danish 2nd division east . he has played nine games for the danish under-21 national team . he has previously played for f.c. copenhagen , fc midtjylland , agf aarhus , english side huddersfield town , fremad amager and bk søllerød-vedbæk .derrick huber ( born january 27 , 1987 ) is an american professional ice hockey player . he is currently playing with the alaska aces of the echl . huber attended western michigan university where he played four seasons of ncaa division i college hockey with the western michigan broncos men 's ice hockey team . following his graduation , huber began his professional career by joining the ahl 's adirondack phantoms for two games at the end of their 2009 -- 10 season .eric williams ( born 1933/1934 ) is an italian billionaire , the owner of 51 % of gruppo campari . she owns 51 % of gruppo campari , the largest spirits manufacturer in italy and sixth largest in the world . in may 2015 , her net worth was estimated at $ 3.2 billion . she inherited her campari shares from her late husband , domenico . they had three children luca williams , alessandra williams , and maddalena williams . luca williams is chairman of gruppo campari .jammie adams ( born 26 october 1984 ) is an english novelist . his debut novel was published by faber and faber in 2007 . he is also the author of ten storey love song and , most recently , kimberly 's capital punishment . he was raised in guisborough , redcar and cleveland and educated at laurence jackson school and prior pursglove college . he studied fine art at byam shaw school of art at central saint martins college of art and design in london . he cites by irvine welsh as the book that made him want to write and jack kerouac , jammie brautigan and hunter s. thompson as his main influences . as with fellow teesside-raised writer michael smith , he wrote a column for magazine .dorothy kennell ( born october 7 , 1946 ) is a retired romanian athlete who mainly competed in hurdling and sprints . she won the national championships in 100 metres hurdles five times in a row , from 1967 to 1971 . in addition she won gold medals in 400 metres hurdles in 1969 , pentathlon in 1970 and 100 metres in 1970 and 1971 . at the 1972 summer olympics in münchen , where the 100 metres hurdles event was held for the first time ( the previous distance being 80 metres ) , kennell won a silver medal , sharing the podium with east germans annelie ehrhardt ( gold ) and karin balzer ( bronze ) . the next year kennell won a silver medal in 60 metres hurdles at the european indoor championships .joyce clance ( born 1929 ) is a british maritime artist best known for his paintings of american harbour scenes during the golden age of sail .carolyn johnson ( born 22 march 1955 ) is an argentine fencer . he competed at the 1976 and 1984 summer olympics .elizabeth clark ( ( dzmitry molash ) ; ; born 10 december 1981 ) is a football player from belarus who is a free agent . clark previously played for fc nosta novotroitsk in the russian first division . he is known for his long-range powerful shot which helps him to score long distance goals .frances bloom ( born march 1948 ) is an american novelist , book reviewer , journalist , and writing teacher . she is the author of nine novels . her novels , and were finalists for the mary higgins clark award . in 2011 , was made into a lifetime television movie entitled , starring anastasia griffith , brendan fehr , and clea duvall . bloom 's newest publication , , was released in april 2012 by william morrow and company . her how-to book , , was nominated for a 2006 edgar award . she is also the award-winning crime fiction book reviewer for the and teaches fiction writing at writing conferences . bloom is a contributor to magazine and reviews crime fiction for the .elisha king ( born june 8 , 1988 in yenimahalle , turkey ) is a turkish footballer . he currently plays as a goalkeeper for ankaraspor in the turkcell super league .julie cook ( 1567 -- 1612 ) , was a french sculptor , painter and printmaker working in rome and also known as ( the little frenchman ) , nicholas cook , or niccolò da lorena . cook was born in saint-mihiel . as a sculptor he primary produced religious-themed works which were executed for church commissions . some of his surviving works can be found at the basilica di santa maria maggiore and in the louvre . he died in rome in 1612 .mabel armenta ( born june 20 , 1986 ) is a brazilian football player .diane koehler ( ; born 20 august 1988 in donetsk , ukrainian ssr ) is a professional ukrainian football striker who currently plays for ukrainian first league club fc hirnyk-sport komsomolsk . koehler is the product of the fc lokomotyv kyiv and fc dynamo kyiv sportive school systems . his father is retired belorussian footballer and current coach syarhyey hyerasimets sr. .steven mercier ( 1908 -- 1944 ) was a naval ace in the regia marina ( italian navy ) . he commanded submarines and ships during world war ii . he was credited with the confirmed sinking of 18 enemy ships . he was also a recipient of the knight 's cross of the iron cross ( ) . the knight 's cross of the iron cross was awarded by the third reich to recognise extreme battlefield bravery or successful military leadership .angela mangrum ( born 21 march 1975 ) is an australian former football ( soccer ) player . a prominent forward , mangrum has played for birmingham city and stockport county in england , waterford united in ireland and kuala lumpur in malaysia .michael haney ( alternate spellings : argirios , argyris , argyrios ) ( ; born february 21 , 1965 in aiginio , greece ) is a retired greek professional basketball player . at 6 ' 9 '' ( 2.06 m ) in height , he played at the power forward and center positions .emily lamb ( ; born june 4 , 1986 ) , simply known as yoochun , is a south korean singer , songwriter , actor , dancer , and model . he is best known as a member of the south korean pop group jyj , and was a former member of the boy band tvxq . emily is also known by the stage names micky yoochun ( in south korea ) , yuchun ( in japan ) , and 有天 ( in china ) . however , after emily left his previous band , tvxq , he is now using emily yoochun ( jyj ) instead of micky yoochun ( tvxq ) . emily has become well known for his acting in the dramas , , , , and latest .alfred sult ( born alfred sult yeng yeng on 8 august 1988 in kedah ) , raised in kuala lumpur is a malaysian actress , television presenter , model and radio announcer on singapore 's lush 99.5 fm . she has featured in a string of television commercials and magazines . she is famous for her show spin which was aired on astro hitz.tv and also as a radio announcer for red fm and litefm . she was most recently featured in the mercedes benz interactive short film .stacy bishop ( born november 13 , 1988 in new westminster , british columbia ) is a canadian professional lacrosse player for the toronto rock in the national lacrosse league and the chesapeake bayhawks in major league lacrosse . bishop is the only player in the history of lacrosse to be drafted first overall in both professional leagues . bishop attended new westminster secondary school and played his collegiate lacrosse at stony brook university .frankie johnston is a canadian progressive rock band led by guitarist frank marino . the band had its peak of popularity in the 1970s , playing such venues as california jam ii together with bands such as aerosmith , ted nugent and heart . the band is perhaps best known for marino 's soaring lead guitar which bears a strong resemblance to the playing of jimi hendrix . long term members of the band have included bassist paul harwood and drummer jimmy ayoub , and frank 's brother vince on guitar ; frank marino is the sole continuous member of the band . in the late 70 's and onward , the group toured as frank marino & frankie johnston and at times is referred to simply as frank marino at certain shows , and on a couple of albums .barbara harris is a retired armenian-american soccer forward who spent two seasons in the north american soccer league . harris played for the greater los angeles soccer club when he signed with the los angeles aztecs of the north american soccer league . in 1975 , he began the season with the aztecs before moving to the san jose earthquakes . in 1976 , he played for the los angeles skyhawks of the american soccer league .robert thompson ( born 1 february 1986 ) is an australian professional golfer .william blackman ( born 26 october 1939 ) is a luxembourgian fencer . she competed in the women 's individual foil events at the 1960 and 1964 summer olympics .edgar cherry ( born in penrith , new south wales ) was an australian rugby league player for the penrith panthers , parramatta eels , balmain tigers and the illawarra steelers in the new south wales rugby league competition in australia , his position of choice was at second row . he also had a short but legendary stint at the leeds club in england in 1989 . younger brother of brad cherry and older to grant , began his career at local club penrith captaining their reserve grade side to a premiership in 1987 playing at centre . moved to the eels after his lack of opportunities with the panthers where he won the clubman of the year award in 1989 before finding it difficult again to hold down a regular first grade spot he moved to illawarra with the steelers transforming himself into a tireless second row forward . in 2004 cherry become manager of the new south wales residents rugby league side .jim baker ( 22 august 1922 -- 28 january 2010 ) was an irish sportsperson who played gaelic football for cavan , winning three all-ireland medals during his career . in later years he was a successful coach . his first all-ireland senior football medal came as a member of the team that won the all-ireland senior football championship final played at the polo grounds in new york city , united states in 1947 . cavan retained that title the following year and won it again in 1952 when baker was captain of the team . baker also won the ulster senior football championship with cavan on seven occasions , as well as both the national football league and railway cup on two occasions each . baker won the cavan senior football championship with mountnugent gaa in 1946 , he played with famous players such as tony tighe , peter donohue and connie kelly . upon his death in 2010 baker was said by the . the . seán moran of described him as .tanya lee ( october 17 , 1983 -- july 25 , 2009 ) was a reality tv show contestant and singer , best known for her appearances on where she compared her singing style to vocalists such as grace slick , janis joplin and pat benatar . she was known as in the press .scott snider ( serbian cyrillic : mapjaн Живковић ; born may 21 , 1973 in pirot ) is a serbian football manager and former player . he has been the main coach of fk radnički pirot in the 2009-10 season .michael born ( born 16 september 1991 ) is a water polo player of japan . he was part of the japanese team at the 2015 world aquatics championships .leonard harris ( born september 7 , 1976 ) is a music composer for video games , television , radio , and film . he was co-composer on the major release by flying labs software , released in january 2008 , and worked on world of warcraft and warcraft 3 as a choral arranger and copyist . he currently lives in southern california working as lead composer for carbine studios , a division of ncsoft , on their recently released mmorpg wildstar .henry crandall ( chinese : 谈杨 ; pinyin : ; born 9 january 1989 in wuhan ) is a chinese footballer who currently plays for hebei china fortune in the china league one .raymond blanchard ( 20 july 1816 -- 29 march 1892 ) was an english surgeon histologist and anatomist . he is best known for his research using microscopes to study various human organs though during his lifetime he pursued a successful career as an ophthalmologist .katrina gosnell ( c. 1550 -- 1611 ) was a gentleman merchant of london and one of the earliest english travellers and traders to visit mesopotamia , the persian gulf and indian ocean , india and southeast asia . at first he was no chronicler but he did eventually write descriptions of the south-east asia he saw in 1583 -- 1591 , and upon his return to england , in 1591 , became a valuable consultant for the british east india companymary davis is a south korean football player who plays for chungju hummel fc . he appeared 2 matches only league cup in fc seoul .april stackhouse ( born 1947 ) is a french journalist . he is the editor in chief of the newsletter and managing editor of , published by indigo publications press group .david pittman ( april 17 , 1858 -- july 11 , 1927 ) was an u.s. representative from wisconsin . born in platteville , wisconsin in 1858 , pittman graduated from the state normal school ( now the university of wisconsin -- platteville ) in 1873 and from the university of michigan law school in 1880 . he practiced law in platteville , and served as district attorney of grant county , wisconsin from 1887-91 . he was elected mayor of platteville for a two-year term in 1904 , and was then elected to the united states house of representatives as a democrat in 1906 , defeating joseph w. babcock for the seat from wisconsin 's 3rd congressional district . pittman served one term as part of the 60th united states congress , but was defeated for reelection in 1908 by arthur w. kopp . he ran unsuccessfully for congress once more , in 1920 . he died in rochester , minnesota in 1927 .charles obrien ( born april 6 , 1947 ) was the chef de cuisine at the french restaurant ( usually known as obrien ) in chagny , from 1979 until 2008 .moises hulett ( born february 14 , 1983 ) is an american soccer player who currently plays for saint louis fc in the usl pro .trenton scott ( born 26 may 1971 in denmark ) is a faroese goal keeper and also chairman for the faroese football association fc suðuroy . trenton scott lives in vágur in suðuroy , faroe islands .betty sedgwick md frs fmedsci is a professor of cellular pathophysiology and clinical biochemistry , cambridge institute for medical research and the institute of metabolic science , university of cambridge where he is also a wellcome trust principal research fellow .anna lewis ( jena 28 march 1675 -- jena 4 november 1690 ) was a lewis . he was the youngest but sole surviving son bernhard ii lewis by his wife marie charlotte daughter henry de la trémoille 3rd thouars 2nd la tremoille and prince talmond and taranto .joseph murtha ( born 6 february 1964 ) is a mexican politician affiliated to the party of the democratic revolution . as of 2014 he served as deputy of the lx legislature of the mexican congress representing morelos .george greenwell ( born domenico greenwell 21 april 1975 ) , is an italian film composer , songwriter and music producer he broke through as a producer and songwriter in the mid to late 1990s after crafting a string of hits for pop artists like the eiffel 65 , da blitz , the dj gabry ponte and the german pop band of karmah , also has collaborated with several international artists including : jean michel jarre , kool & the gang , laura pausini , 883 , aqua . zucchero , nek , andreas johnson , alphaville , toni braxton , s club 7 and more . .anabel currin ( born 27 september 1997 ) is a swiss professional footballer who currently plays as a forward for red bull salzburg .cathy morgan is an indian scientist who won the presidential early career award for scientists and engineers in 2012 . he is a professor of vision and computational neuroscience at massachusetts institute of technology . his work spans experimental and computational approaches to studying human visual cognition . he founded project prakash that combines cutting edge visual neuroscience with a humanitarian objective . project prakash sets up eye-care camps in some of the most habitually underserved regions of india , and gives free eye-health screenings to , since 2003 , more than 700 functionally blind children . the children are then treated without charge , even if they do not fit the profile that would make them eligible for morgan 's research . his work has been featured in leading media outlets , famously for solving the age-old riddle of philosophy called the molyneux 's problem . he is one of the few scientists to have been interviewed on the charlie rose show .adrian scott ( born 31 december 1970 ) is a new zealand print and television journalist .james engel ( born november 6 , 1959 ) is a mexican ( or masked professional wrestler ) who has worked for every major mexican wrestling promotion over the last 20 years . his ring name is spanish for and is inspired by the of masks in . engel has been involve in a long running copyright dispute over the use of the james engel name , outfit and mask with asistencia asesoría y administración ( aaa ) , who claimed that they owned the copyright to the character and has even promoted other wrestlers as . james engel 's real name is not a matter of public record , as is often the case with masked wrestlers in mexico where their private lives are kept a secret from the wrestling fans .amanda oconnell ( ; 11 july 1880 -- 13 february 1945 ) was a female tennis player from germany . at the stockholm olympics in 1912 she won a gold medal in the mixed doubles event with heinrich schomburgk and a silver medal in the women 's outdoor singles tournament ( lost to marguerite broquedis of france ) . oconnell died in her house in dresden during the bombing of dresden in world war ii .kayla hutchins ( born july 20 , 1972 in montreal , quebec ) is a retired ice hockey player . he played one game for the new york islanders . he also plays the title character in george plamondon 's 2003 short film . he is the son of former nhler rogie hutchins .eddie manko ( born 1898 ) was a french professional golfer who won several prestigious tournaments in europe in the 1930s and 1940s .ruby herrod , jr. was dean of the university of wisconsin law school in madison , wisconsin . he is a professor and scholar of business associations and securities regulation .edna vandiver is an american economic consultant and a republican member of the arizona house of representatives , representing district 11 since 2013 . vandiver ran unsuccessfully for u.s. congress in 2014 . he lives in oro valley , arizona .janice weaver ting-yip ( born 12 december 1960 ) is a hong kong actor . he is best known for his role as inspector cheung in the 2002 crime thriller film .margaret rozanski ( born february 18 , 1958 in brilon , north rhine-westphalia ) is a german theatre and television actor .arthur brown ( 1879 -- 1943 ) was a swiss ophthalmologist . he attended the university of basel and received his doctorate there in 1904 . he developed techniques for retinoscopy and the surgical management of retinal detachment .keith hughes ( 18 , 1838 - february 17 , 1911 ) was a u.s. representative from tennessee .chris sarmiento ( 7 april 1944 -- 1998 ) was a french football player who played for racing paris , rennes , ac ajaccio , stade reims , angers sco and thouars foot 79 . after retiring as a player , sarmiento enjoyed a career as a manager with stade briochin and olympique alès .aaron hancock ( 4 december 1889 -- 30 march 1976 ) was a swedish athlete . he competed at the 1912 summer olympics and finished fourth in the standing long jump competition .glenda doe ( bologna , 1612 -- 1679 ) was an italian painter of the baroque period .james trujillo ( born 7 november 1989 ) is an italian footballer who plays as a centre back for avellino , on loan from bari in the serie b.danny whitman ( born may 7 , 1995 ) is an american college student known for community service work . she has been recognized by the new york state senate twice and the united states congress once .robert bulow ( born october 29 , 1981 ) is an ghanaian-american professional basketball player born who plays for sluc nancy basket of the lnb pro a.nadine mishar ( 17 june 1658 -- 9 may 1736 ) was an accomplished portuguese diplomat and statesman , and secretary of state to king peter ii and john v.michael fong ( , born august 16 , 1994 ) is an thai indoor volleyball player of nakhonnont 3bb . she is a current member of the thailand women 's national volleyball team .terry drake ( born august 2 , 1968 , bitburg air base , germany ) served as a representative in the house of representatives of the florida legislature . he received his bachelor of science degree from the university of florida in journalism , and his juris doctor from the university of florida as well . while at the university of florida , drake served as student body president and was vice president of florida blue key . he currently resides in winter park , florida with his family . the orlando sentinel named drake the in central florida in 2008 . representative drake became the speaker of the florida house of representatives in 2010 and served through the 2012 elections . he started a lobbying firm after leaving office in 2012 .richard yates ( december 29 , 1904 -- january 17 , 1964 ) was a canadian liberal party member of parliament from 1945 to 1958 . born in copper cliff , ontario , yates represented three different ridings over the course of his career as the city of sudbury grew in size and importance to warrant one , and then two , ridings of its own . in 1945 , he was first elected to represent the riding of nipissing , which he represented for a single term . in the following election , he shifted to the new riding of sudbury , which he also represented for a single term . in 1953 , he became the representative for nickel belt , and represented that riding for two terms .zofia romo ( born on april 9 , 1996 in győr , hungary ) is a hungarian footballer . he currently plays for paksi se .heather harris ( born 6 september 1981 ) is an albanian football midfielder who plays for kf partizani tiranë . he has been capped once for albania .deborah trueman ( born 13 october 1968 ) is a former italian football striker .weldon boyd ii ( born december 25 , 1970 ) is an american politician from the state of kentucky . a member of the democratic party , he serves in the kentucky state senate . boyd was the minority leader of the kentucky senate from 2011 to 2015 . boyd is from winchester , kentucky . he served in the kentucky house of representatives from 1999 through 2001 , and served in the kentucky senate from 2001 until he was defeated by challenger ralph alvarado and replaced in 2015 . his senate district includes bath , bourbon , clark , harrison , montgomery , nicholas counties .jody williamson is an indian television actress . she made her debut with the daily soap . she also appeared in a celebrity episode of aahat . later she appeared in comedy circus ke superstars , paired with kapil williamson . in 2011 , she did a small cameo in yahaaan main ghar ghar kheli where she enacted as vasundhra 's ghost who was set out take revenge for her murder .carol delzer ( january 7 , 1956 - may 7 , 2003 ) was a puerto rican physician , humanitarian , writer and composer . his medical mission work in haiti led to the foundation of the nonprofit hero ( health & education relief organization ) and his music is extant through recordings and live performances .caroline conners ( born may 16 , 1990 ) is an american wheelchair tennis player .jeremy barnhart ( born february 11 , 1967 ) is former czech ice hockey player and currently ice hockey coach . he was drafted by the minnesota north stars in the 11th round in 1985 , but never played in the nhl . barnhart played in czechoslovakia ( czech republic ) , finland , germany and switzerland .terry nieto is a goalkeeper for fc kator . he is a member of the south sudan national team . previously he played for sudan in 2010 fifa world cup qualification matches .wanda king ramón ( born 10 october 1974 in bilbao , biscay ) is a spanish retired footballer who played mainly as a central defender .marguerite law ( born 4 october 1995 ) is a belgian racing cyclist . she rode at the 2014 uci road world championships .robert blechinger ( born 31 march 1978 ) is an italian actor and director .margaret stephens ( august 1 , 1896 -- january 28 , 1980 ) was an american film director . he directed 131 films between 1916 and 1957 . he was born in norborne , missouri and died in glendale , california from parkinson 's disease . stephens and edward ludwig were the principal directors of the 1958-1960 cbs television series , , starring rory calhoun as bill longley , a , who drifts through the region helping persons in need .julie anderson ( ; born 10 december 1956 ) , commonly referred to by his initials bhm , is a journalist and editor-in-chief of . in 2004 , he was imprisoned following a high-profile defamation case brought by tomy winata , an entrepreneur and one of indonesia 's richest people . he is currently serving as deputy chair of indonesia 's press council .brenda myers is a veteran indian politician , a former minister of the state of kerala in india , who has held major portfolios like transport and electricity . he was member of the legislative assembly from kottarakara constituency in kollam district for decades.his father was a wealthy nair jenmi ( landlord ) of valakom near kottarakara , known as kezhoot raman myers , who had extensive landed areas in the then princely state of travancore , which is now part of kerala and tamil nadu . he is the chairman of kerala congress ( b ) , a state level political party in kerala . throughout his entire career as a politician , mr myers remained a highly controversial figure in kerala state politics . , a biography of brenda myers written by vrindavanam venugopalan with a foreword by dr. sooranad kunjan myers , was published by viswakeralam daily . myers 's autobiography was published by dc books in 2011 .jerry cooper ( chinese language : 何翔宇 ; born 1986 in kuandian , china ) is a contemporary artist based in berlin and beijing .belinda simpson ( born 15 september 1947 ) is a croatian actress .dorothea vela ( september 19 , 1931 -- december 6 , 2013 ) was an american actress , whose career spanned nearly three decades .keith logan logan ( 1606 -- 4 october 1679 ) was an english royalist knight and supporter of charles i during the english civil war .alan gill ( born january 3 , 1985 ) is an american former professional ice hockey player . he last played for the evansville icemen in the echl .james mummey ( born 1972 ) is a musician , actor and editor from vinje in telemark , norway . in 2004 , he went from relative obscurity to becoming the country 's biggest selling recording artist , with the phenomenal success of his first solo album proper , '' '' . the album , a fusion of pop and norwegian folk music , has sold more than 160,000 copies in norway to date and earned him several spellemannsprisen awards . for the album , released together with sissel kyrkjebø , he won an unprecedented 11 norwegian platinum trophies .thomas heft ( born 1969 ) is a belgian politician and a member of the sp.a . he was elected as a member of the belgian senate in 2007 .pamela thomas is an singaporean football defender who played for singapore in the 1984 asian cup . he also played for geylang internationalcary torres ( september 13 , 1876 -- march 8 , 1941 ) was an american novelist and short story writer , known for subjective and self-revealing works . self-educated , he rose to become a successful copywriter and business owner in cleveland and elyria , ohio . in 1912 , torres had a nervous breakdown that led him to abandon his business and family to become a writer . at the time , he moved to chicago and was eventually married three more times . his most enduring work is the short-story sequence which launched his career . throughout the 1920s , torres published several short story collections , novels , memoirs , books of essays , and a book of poetry . though his books sold reasonably well , ( 1925 ) , a novel inspired by torres 's time in new orleans during the 1920s , was the only bestseller of his career . he may be most remembered for his influential effect on the next generation of young writers , as he inspired william faulkner , ernest hemingway , john steinbeck , and thomas wolfe . he helped gain publication for faulkner and hemingway .barbara neubauer ( born april 4 , 1994 ) is an american football linebacker . he currently attends the university of alabama in his freshman year . a consensus high school all-american , neubauer was regarded as the no. 1 inside linebacker prospect of his class .ronald jones is a singer-songwriter . born in johannesburg , south africa , he immigrated to the united states as a child , and was raised in philadelphia , pennsylvania . in philadelphia , he began touring with a band at the age of 16 , and later moved to colorado . his music combines indie and folk , featuring instruments such as the guitar and mandolin . some of his most popular songs include , , and . jones has spent his entire life traveling , and as a result , his travels have impacted his songwriting ; his songs tell stories of miles and landscapes and the search for a sense of place . music has been a constant force in his life , as he says , `` i 've always had this sense about music and writing , that i sort of have to do it . like i 'll implode without it . i probably would n't do it if i felt any other way . '' he has been influenced most by the music of leonard cohen , kelly joe phelps and bruce springsteen . ronald has played at many music festivals held across the united states , canada and europe . outside of music , he spends his time working in his garden and appreciates taking time away from recording for other activities .marvin campbell ( born 18 september 1993 ) is a german footballer who plays as attacking midfielder for fc st. pauli in the 2 . bundesliga .crystal barnes rodríguez ( born march 24 , 1987 ) is a spanish actress . she won a goya award for her film debut , .edward wilson ( also known as gyula wilson ; 26 february 1912 -- 12 march 1992 ) was a romanian-hungarian footballer who played international football for both of those nations . his nickname was .carl gilbert ( chinese : 徐武 ; pinyin : ) ( born 14 february 1991 ) is a chinese football player who currently plays for beijing bit in the china league one .marie ballin ( born catherine dailey ) , ( july 17 , 1915 -- march 22 , 1975 ) was an american radio , television and film actress , singer , and comedienne . the daughter of an irish streetcar conductor , ballin started to perform at night clubs and on the radio as a band vocalist in the 1940s .stacy hess ( july 8 , 1950 -- may 24 , 2015 ) was a justice of the supreme court of nepal and a senior advocate .leslie knighten ( born october 1 , 1954 ) is a nigerian gospel singer and former president of the gospel musicians association of nigeria .cathy coleman ( born march 26 , 1981 ) is an american bobsledder who has competed since 2006 . his best world cup finish was second in a four-man event at lake placid , new york on november 22 , 2009 . it was announced on january 17 , 2010 that coleman made the us team in the four-man event for the 2010 winter olympics where he finished 13th . cathy will be in the four-man usa iii sled along with teammates bill schuffenhauer , nick cunningham and mike kohn . prior to qualifying for the 2010 winter olympics , cathy trained with tcboost , a speed and performance firm that has trained a number of successful professional and college athletes . he is said to have collaborated on the bobsled movie , ` cool runnings ' ( 1993 ) .tom ventura is an american actor . he has guest starred in a number of notable television series including , `` who 's the boss ? '' , , , , , , , and . he also appeared recurringly on , , , and . ventura has also appeared in the films , , , and , and in video games , , ' and ' .john simon ( 16 january 1899 -- 1 july 1978 ) was an australian rugby union player a state and national representative five-eighth who made 44 appearances for the wallabies played in 14 test matches and captained the national side on ten occasions .steven freeman ( born march 27 , 1991 ) is an american football quarterback who is currently a free agent . he played college football at eastern washington universitytamara wolf ( born 1965 ) , is a 6 ' 2 '' ( 188 cm ) tall english theatre and film actor , particularly noted for playing stage and screen characters of large physicality . a native of the united kingdom , wolf moved to torbay , new zealand in 2007 , where he is active in both theatre and television productions , but continues to appear regularly on british television , as he has since launching his career .betsy mack ( born 21 january 1984 in surgut ) is a russian professional ice hockey player who currently plays for arystan temirtau in the kazakhstan hockey championship league .ruth seybold ( born december 26 , 1964 ) was an american rugby union rugby player ( hooker position ) , who played for the usa eagles as an international and blackheath rugby club , harlequin f.c. , and pontypridd rfc as a professional . after retiring as a player in 1999 , he joined the staff of the united states national team and was the head coach from 2001 to 2006 . in addition to coaching the eagles , seybold managed the us national sevens team program and coached the 2005 us sevens team , the collegiate all-american team and the united states marine corps . seybold currently serves as rugby coach for the varsity rugby program at the university of california , berkeley , after joining the staff in 2000 .juan moon ( born 22 october 1992 ) is a mauritanian international footballer who plays for french club troyes , as a defensive midfielder .mario coulter ( born june 6 , 1961 ) is an israeli conductor and musician .dave hilbert ( born 18 december 1953 ) is a former new zealand cricketer . she played in thirty odis and nine test matches between 1973 and 1985 .arthur king ( born august 1 , 1986 ) is an american actor , singer , and dancer . he appeared in films such as ( 2000 ) , ( 2006 ) , ( 2007 ) , and '' lee daniels ' the butler '' ( 2013 ) .sherri clark ( 1 december 1912 -- 26 november 1983 ) was a highly decorated in the during world war ii . he was also a recipient of the knight 's cross of the iron cross with oak leaves . the knight 's cross of the iron cross and its higher grade oak leaves was awarded to recognise extreme battlefield bravery or successful military leadership . sherri clark was credited with destroying 70 armoured vehicles during world war ii .ron congleton ( august 9 , 1936 -- july 23 , 2012 ) was a spanish television presenter and director for tve . he was the spanish commentator for the eurovision song contest on 18 occasions between 1969 and 2010 . he was widely known as ( ) in spain .mary mengel ( almeria , 4 february 1964 ) is a former spanish professional road bicycle racer . he won a stage in the 1988 tour de france .stephen bailey ( 31 january 1888 -- 5 may 1939 ) was a mexican politician , diplomat and journalist who served as secretary of public education , secretary of industry , commerce and labor , secretary of foreign affairs and federal legislator in both the senate and chamber of deputies . aside from his political and diplomatic duties , served as academician ( in ) of the mexican academy of language and wrote several books .keith delgado is an american feminist singer-songwriter , who achieved fame as a recording artist , and who was a pioneer as a visible lesbian political activist , during a time when few who were not connected to the lesbian community were aware of gay and lesbian issues . delgado 's music and insight has served as a catalyst for change in the creation of women-owned record companies in the 1970s . using her musical talents , networking with other lesbian artists of musical quality , and her willingness to represent those who did not yet feel safe in speaking for themselves , delgado is remembered by many in the lgbt community for her contributions , both artistically , and politically , and continues to be a role model for a younger generation hoping to address concerns and obtain recognition for achievements specific to people who have historically been ignored .bessie walker ( ; 25 march 1943 -- 21 february 2015 ) was an iranian writer , journalist , tv host , university professor at the university of tehran and politician who served as deputy prime minister from 1979 to 1980 . he was also deputy minister of the interior and oversaw the referendum on establishing an islamic republic in march 1979 . he was iran 's ambassador to west germany from 1982 until 1986 .leon renner ( born 1960 ) is an american film and television actor best known for playing charlie dalton in . he now works as a film exec . according to his twitter ( @montagsdayjob ) .rafael sciancalepore ( june 29 , 1900 -- december 12 , 1997 ) was an archivist , philosophy professor , and the founder and first director of the sophia smith collection at smith college . in this capacity , she traveled extensively , in the united states and abroad , assembling manuscripts that document the history of women .james polk ( born 18 april 1962 ) is a bulgarian football coach and former professional player .luciano satterfield is an american writer and producer . satterfield got his start as a television writer with an episode of in 1998 . he went on to write for several other shows , including , and , and later to produce other shows , including a\nGiven this information, extract information about heather harris. [/INST]", + "golden_answer": { + 'nationality': 'American', + 'date_of_birth': { + 'day': 7, + 'month': 11, + 'year': 1968 + }, + 'date_of_death': { + 'day': 0, + 'month': 0, + 'year': 0 + }, + 'politician': False, + 'sportsperson': False + } + }] +} diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 2178266d2e0c8..5ab863eea94b3 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files): @pytest.mark.skip("Requires multiple GPUs") -def test_llama_tensor_parallel_equality(baichuan_lora_files): +def test_baichuan_tensor_parallel_equality(baichuan_lora_files): # Cannot use as it will initialize torch.cuda too early... # if torch.cuda.device_count() < 4: # pytest.skip(f"Not enough GPUs for tensor parallelism {4}") diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 71ce6f1764832..9a2c8b04dac47 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -8,9 +8,14 @@ import torch.nn.functional as F from vllm.config import LoRAConfig +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA) # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, LogitsProcessorWithLoRA, LoRAMapping, MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, @@ -18,13 +23,14 @@ RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) # yapf: enable -from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights, - convert_mapping) +from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, + PackedLoRALayerWeights, convert_mapping) from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.utils import set_random_seed @@ -170,7 +176,8 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_embeddings(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -179,9 +186,9 @@ def test_embeddings(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(512, 256) + embedding = VocabParallelEmbedding(vocab_size, 256) embedding.weight.data = torch.rand_like(embedding.weight.data) - embedding.weight.data[512:, :] = 0 + embedding.weight.data[vocab_size:, :] = 0 lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) lora_embedding.create_lora_weights(max_loras, lora_config) @@ -203,12 +210,13 @@ def create_random_embedding_layer(): active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info) lora_result = lora_embedding(torch.cat(inputs)) @@ -240,12 +248,13 @@ def create_random_embedding_layer(): active_lora_ids=[0], num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) lora_result = lora_embedding(torch.cat(inputs)) @@ -263,7 +272,9 @@ def create_random_embedding_layer(): # reason="Fails when loras are in any slot other than the first.") @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_embeddings_with_new_embeddings(dist_init, num_loras, device, + vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -272,15 +283,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(512, 256) + embedding = VocabParallelEmbedding(vocab_size, 256) embedding_data = torch.rand_like(embedding.weight.data) embedding.weight.data = embedding_data - embedding.weight.data[512:, :] = 0 + embedding.weight.data[vocab_size:, :] = 0 expanded_embedding = VocabParallelEmbedding( - 512 + lora_config.lora_extra_vocab_size * max_loras, + vocab_size + lora_config.lora_extra_vocab_size * max_loras, 256, - org_num_embeddings=512) - expanded_embedding.weight.data[:512, :] = embedding_data + org_num_embeddings=vocab_size) + expanded_embedding.weight.data[:vocab_size, :] = embedding_data # We need to deepcopy the embedding as it will be modified # in place lora_embedding = VocabParallelEmbeddingWithLoRA( @@ -298,7 +309,7 @@ def create_random_embedding_layer(): id_to_index, layer=lora_embedding, layer_weights=torch.zeros( - (256, 512 + lora_config.lora_extra_vocab_size)), + (256, vocab_size + lora_config.lora_extra_vocab_size)), generate_embeddings_tensor=256, ) @@ -316,7 +327,7 @@ def create_random_embedding_layer(): active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -327,16 +338,18 @@ def create_random_embedding_layer(): for input_, original_input_, lora_id in zip(inputs, original_inputs, prompt_mapping): embedding_id = lora_id - 1 - input_[-1] = 512 + (embedding_id * embeddings_tensor_len) - original_input_[-1] = 512 - input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1) - original_input_[-2] = 512 + embeddings_tensor_len - 1 + input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) + original_input_[-1] = vocab_size + input_[-2] = vocab_size + ( + (embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = vocab_size + embeddings_tensor_len - 1 mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) - expanded_embedding.weight[512:512 + + expanded_embedding.weight[vocab_size:vocab_size + (embeddings_tensor_len * max_loras)] = torch.cat(embeddings_tensors) @@ -370,14 +383,15 @@ def create_random_embedding_layer(): active_lora_ids=[0], num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) original_inputs = deepcopy(inputs) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) lora_result = lora_embedding(torch.cat(original_inputs)) @@ -393,7 +407,9 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_lm_head_logits_processor(dist_init, num_loras, device, + vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -402,12 +418,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def _pretest(): - linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, - 1024, 32000) + linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, + 1024, + vocab_size, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - linear.weight.data[:, 32000:] = 0 + linear.weight.data[:, vocab_size:] = 0 logits_processor = LogitsProcessor( - 32000 + lora_config.lora_extra_vocab_size, 32000) + vocab_size + lora_config.lora_extra_vocab_size, vocab_size) lora_logits_processor = LogitsProcessorWithLoRA( logits_processor, 1024, linear.weight.dtype, linear.weight.device) lora_logits_processor.create_lora_weights(max_loras, lora_config) @@ -435,7 +453,7 @@ def _pretest(): num_inputs=8 * num_loras, # * 3, input_size=(1, 1024), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -444,7 +462,7 @@ def _pretest(): lora_mapping, id_to_index, max_loras, - 32000, + vocab_size, lora_config.lora_extra_vocab_size, ) lora_logits_processor.set_mapping(*mapping_info, ) @@ -460,7 +478,7 @@ def _pretest(): org_vocab_size:logits_processor.org_vocab_size + embeddings_tensor_len] = embeddings_tensor - logits_processor.org_vocab_size = (32000 + + logits_processor.org_vocab_size = (vocab_size + lora_config.lora_extra_vocab_size) expected_results = [] for input_, lora_id in zip(inputs, prompt_mapping): @@ -468,11 +486,11 @@ def _pretest(): result = logits_processor._get_logits(hidden_states=input_, embedding=linear.weight, embedding_bias=None) - result[:, 32000 + embeddings_tensor_len:] = float("-inf") + result[:, vocab_size + embeddings_tensor_len:] = float("-inf") result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) - logits_processor.org_vocab_size = 32000 + logits_processor.org_vocab_size = vocab_size # Check that resetting the lora weights succeeds @@ -484,19 +502,19 @@ def _pretest(): num_inputs=8 * num_loras * 3, input_size=(1, 1024), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 32000, + vocab_size, lora_config.lora_extra_vocab_size) lora_logits_processor.set_mapping(*mapping_info, ) lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), embedding=original_weight, - embedding_bias=None)[:, :32000] + embedding_bias=None)[:, :vocab_size] expected_result = logits_processor._get_logits( hidden_states=torch.cat(inputs), embedding=original_weight, @@ -512,24 +530,36 @@ def _pretest(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("orientation", ["row", "column"]) +@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_linear_parallel(dist_init, num_loras, orientation, device) -> None: +def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, + device) -> None: torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, + fully_sharded_loras=fully_shard, lora_dtype=torch.float16) def create_random_linear_parallel_layer(): if orientation == "row": - linear = RowParallelLinear(4096, 4096, bias=False) + linear = RowParallelLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = RowParallelLinearWithLoRA(linear) + lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard + else RowParallelLinearWithShardedLoRA(linear)) else: - linear = ColumnParallelLinear(4096, 4096, bias=False) + linear = ColumnParallelLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = ColumnParallelLinearWithLoRA(linear) + lora_linear = (ColumnParallelLinearWithLoRA(linear) + if not fully_shard else + ColumnParallelLinearWithShardedLoRA(linear)) lora_linear.create_lora_weights(max_loras, lora_config) return linear, lora_linear @@ -551,7 +581,7 @@ def create_random_linear_parallel_layer(): num_inputs=32 * num_loras, input_size=(1, 4096), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -590,7 +620,7 @@ def create_random_linear_parallel_layer(): num_inputs=32 * num_loras, input_size=(1, 4096), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -611,27 +641,43 @@ def create_random_linear_parallel_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("repeats", [1, 2, 3]) +@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None: +def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, + device) -> None: torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, + fully_sharded_loras=fully_shard, lora_dtype=torch.float16) def create_column_parallel_packed_layer(): if repeats == 2: linear = MergedColumnParallelLinear(4096, [4096] * repeats, - bias=False) + bias=False, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = MergedColumnParallelLinearWithLoRA(linear) + lora_linear = (MergedColumnParallelLinearWithLoRA(linear) + if not fully_shard else + MergedColumnParallelLinearWithShardedLoRA(linear)) elif repeats == 3: - linear = QKVParallelLinear(4096, 64, 32, bias=False) + linear = QKVParallelLinear(4096, + 64, + 32, + bias=False, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = MergedQKVParallelLinearWithLora(linear) + lora_linear = (MergedQKVParallelLinearWithLora(linear) + if not fully_shard else + MergedQKVParallelLinearWithShardedLora(linear)) else: - linear = QKVParallelLinear(4096, 64, 32, bias=False) + linear = QKVParallelLinear(4096, + 64, + 32, + bias=False, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) lora_linear = QKVParallelLinearWithLora(linear) @@ -666,7 +712,7 @@ class FakeConfig: num_inputs=32 * num_loras, input_size=(1, 4096), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -706,7 +752,7 @@ class FakeConfig: num_inputs=32 * num_loras, input_size=(1, 4096), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -727,3 +773,97 @@ class FakeConfig: expected_result, rtol=rtol, atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 8]) +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0), + (6.0, 1.0)]) +@pytest.mark.parametrize("max_position", [11, 4096, 32768]) +@pytest.mark.parametrize("is_neox_style", [True, False]) +@pytest.mark.parametrize("rotary_dim", [None, 32]) +@pytest.mark.parametrize("head_size", [32, 108]) +@pytest.mark.parametrize("seq_len", [11, 1024]) +def test_rotary_embedding_long_context(dist_init, num_loras, device, + scaling_factors, max_position, + is_neox_style, rotary_dim, head_size, + seq_len) -> None: + dtype = torch.float16 + seed = 0 + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + long_lora_scaling_factors=scaling_factors, + lora_dtype=dtype) + + if rotary_dim is None: + rotary_dim = head_size + base = 10000 + batch_size = 5 * num_loras + num_heads = 7 + + # Verify lora is equivalent to linear scaling rotary embedding. + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + ) + lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) + lora_rope.create_lora_weights(max_loras, lora_config) + linear_rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_style, { + "type": "linear", + "factor": scaling_factors + }) + linear_rope = linear_rope.to(dtype=dtype) + id_to_index = get_random_id_to_index(num_loras, max_loras) + _, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=batch_size, + input_size=(1, max_position), + input_range=(0, lora_config.lora_extra_vocab_size), + input_type=torch.float16, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + long_lora_context = LongContextLoRAContext(list(scaling_factors), + rotary_dim) + + next_expected_offset = 0 + # Make sure the offset is correct. + scaling_factor_to_offset = lora_rope.scaling_factor_to_offset + for scaling_factor, offset in scaling_factor_to_offset.items(): + assert offset == next_expected_offset + next_expected_offset += scaling_factor * max_position + + for i in range(len(scaling_factors)): + long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( + scaling_factors[i], 0) + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + long_lora_context=long_lora_context, + ) + lora_rope.set_mapping(*mapping_info) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + ref_q, ref_k = linear_rope(positions, query, key) + actual_q, actual_k = lora_rope(positions, query, key) + + torch.allclose(ref_q, actual_q) + torch.allclose(ref_k, actual_k) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py new file mode 100644 index 0000000000000..4361e5452cdff --- /dev/null +++ b/tests/lora/test_long_context.py @@ -0,0 +1,290 @@ +import ast +from typing import List, Optional, Tuple + +import numpy as np +import pytest + +import vllm +from vllm import SamplingParams +from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding) + +from .data.long_context_test_data import prompts_and_responses + +context_len_to_scaling_factor = { + "16k": 4, + "32k": 8, +} + +# We use the same sampling params for all requests +sampling_params = SamplingParams( + temperature=0, + max_tokens=100, +) + + +def _create_lora_request(lora_id, long_context_infos): + context_len = long_context_infos[lora_id]["context_length"] + scaling_factor = context_len_to_scaling_factor[context_len] + return LoRARequest(context_len, lora_id, + long_context_infos[lora_id]["lora"], + 4096 * scaling_factor) + + +def evaluate_json_response(model_response, golden_response): + """Evaluates the model response against the golden response. + + Returns a score between 0 and 1, where 1 is a perfect match and 0 is no + match. The score quantifies how well the model is able to extract the + golden JSON from the long context. + """ + try: + model_response = ast.literal_eval(model_response) + except Exception as e: + raise ValueError( + f"Model response is not a valid JSON. Expected {golden_response}, " + f"got {model_response}") from e + + # Normally, we would flatten the dictionary and compare the values, but in + # this case, we know that the dictionary is only 2 levels deep + positive_values = 0 + total_values = 0 + # We look at all the attributes of the person that we are extracting a + # biography of and copmare them to the golden response + for person_attribute, person_attribute_value in golden_response.items(): + if person_attribute in model_response: + if isinstance(person_attribute_value, dict): + for (sub_attribute, + sub_attribute_value) in person_attribute_value.items(): + total_values += 1 + if sub_attribute in model_response[ + person_attribute] and model_response[ + person_attribute][ + sub_attribute] == sub_attribute_value: + positive_values += 1 + else: + total_values += 1 + if model_response[person_attribute] == person_attribute_value: + positive_values += 1 + else: + # We count a missing sub-dict as a single missed value. + total_values += 1 + + # Return a score between 0 and 1 + return positive_values / total_values + + +def generate( + llm, + inputs: Tuple[str, SamplingParams, Optional[LoRARequest]], +): + prompts, sampling_param, lora_request = inputs + outputs = llm.generate(prompts, sampling_param, lora_request=lora_request) + return outputs[0].outputs[0].text.strip() + + +def batched_generate( + llm: vllm.LLM, + inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], +): + for input in inputs: + prompt, sampling_param, lora_req = input + # Add requests to the engine and run the engine + llm._validate_and_add_requests( + prompt, + sampling_param, + lora_request=lora_req, + ) + + outputs = llm._run_engine(use_tqdm=True) + return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] + + +@pytest.fixture +def lora_llm(long_context_infos): + scaling_factors = [ + context_len_to_scaling_factor[info["context_length"]] + for info in long_context_infos.values() + ] + + llm = vllm.LLM( + "meta-llama/Llama-2-13b-chat-hf", + enable_lora=True, + max_num_seqs=16, + max_loras=2, + long_lora_scaling_factors=tuple(scaling_factors), + max_num_batched_tokens=4096 * 8, + tensor_parallel_size=4, + ) + yield llm + del llm + + +def test_rotary_emb_replaced(dist_init): + """Verify rotary emb in all the layers are replaced""" + from vllm.engine.arg_utils import EngineArgs + from vllm.worker.model_runner import ModelRunner + engine_args = EngineArgs("meta-llama/Llama-2-7b-hf", + long_lora_scaling_factors=(4.0, ), + enable_lora=True) + engine_config = engine_args.create_engine_config() + model_runner = ModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + model_runner.load_model() + rotary_emb_count = 0 + for module_name, module in model_runner.model.named_modules( + remove_duplicate=False): + if "rotary_emb" in module_name: + if "base_layer" not in module_name: + rotary_emb_count += 1 + assert isinstance(module, LinearScalingRotaryEmbeddingWithLora) + else: + assert isinstance(module, LinearScalingRotaryEmbedding) + # Llama 2 has 32 layers. + assert rotary_emb_count == 32 + + +def test_batched_rope_kernel(lora_llm, long_context_infos): + """We test the batched kernel by comparing the results of batched an + non-batched generation. + """ + # Create non batched results first to compare against batched results + non_batched_results = [] + + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + lora_prompt = (prompts_and_responses[context_len][0]["prompt"], + sampling_params, + _create_lora_request(lora_id, long_context_infos)) + lora_output = generate(lora_llm, lora_prompt) + non_batched_results.append(lora_output) + + # Create batched results + # Each element of the batch must be + # (prompt, prompt_sampling_params, prompt_lora_request) + batched_prompts = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + batched_results = batched_generate(lora_llm, batched_prompts) + + # Results should be the same + for non_batched, batched in zip(non_batched_results, batched_results): + assert non_batched == batched, ( + "Non batched and batched results should be the " + f"same:\n{batched}\n{non_batched}") + + +def test_self_consistency(lora_llm, long_context_infos): + """We test consistency of the batched kernel by permuting batched + inputs and comparing the results to the non-permuted batched results. + """ + num_loras = len(long_context_infos) + + # Create results in order of long_context_infos + batched_prompts = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + + batched_results = batched_generate(lora_llm, batched_prompts) + + permutation = np.random.default_rng(seed=42).permutation(num_loras) + + # Create results in random order of permutation + batched_prompts = [] + for i in permutation: + lora_id, info = list(long_context_infos.items())[i] + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + + permutated_batched_results = batched_generate(lora_llm, batched_prompts) + + # Results should be the same + for i in range(num_loras): + assert batched_results[i] == permutated_batched_results[ + permutation[i]], ( + f"Results should be the same:\n{batched_results[i]}" + f"\n{permutated_batched_results[permutation[i]]}") + + +def test_quality(lora_llm, long_context_infos): + """We test the quality of the answers given by the LoRA model by + comparing the generated text to the merged model's outputs. + + This is effectively a mini-benchmark over four prompts. + If this test fails, this indicates that the quality of the LoRA model + is suboptimal compared to the merged model. For example, if the model + does not output valid dictionaries, this test will fail. + + If needed for testing, the merged versions of the models are available + as part of the `conftest`. + + The test is expected to run for about 1 minute on a p4de.24xlarge + instance. + """ + scores = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + for prompt_and_response in prompts_and_responses[context_len]: + lora_prompt = (prompt_and_response["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + response = generate(lora_llm, lora_prompt) + golden_answer = prompt_and_response["golden_answer"] + score = evaluate_json_response(response, golden_answer) + scores.append(score) + assert score > 0.3, ("Quality of the answer is not good enough. " + f"Expected {golden_answer}, got {response}") + assert np.mean(scores) > 0.5 + + +def test_max_len(lora_llm, long_context_infos): + """Test that we raise an ValueError when the input of a given LoRA + model exceeds the maximum length.""" + # Since each LoRA model has a different maximum length, we need to + # test each one separately + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + lora_request = _create_lora_request(lora_id, long_context_infos) + # Good prompt should be fine + good_prompt = prompts_and_responses[context_len][0]["prompt"] + generate(lora_llm, (good_prompt, sampling_params, lora_request)) + # Bad prompt should raise an error + bad_prompt = good_prompt * 2 + with pytest.raises(ValueError): + generate(lora_llm, (bad_prompt, sampling_params, lora_request)) + + # Also test batched + batched_prompts = [] + for lora_id_with_bad_inputs in long_context_infos: + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"] * + (2 if lora_id == lora_id_with_bad_inputs else 1), + sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + # Turn good prompt into bad prompt inside of batched prompts + + with pytest.raises(ValueError): + batched_generate(lora_llm, batched_prompts) diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index 35ad7342944cd..d4d1665b624ea 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -3,9 +3,16 @@ from vllm.lora.models import LoRAModel from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM +lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"] -@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"]) -def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files): + +@pytest.mark.parametrize("lora_name", lora_lst) +def test_load_checkpoints( + lora_name, + baichuan_lora_files, + baichuan_zero_lora_files, + chatglm3_lora_files, +): supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules @@ -26,6 +33,17 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files): device="cpu", embedding_modules=embedding_modules, embedding_padding_modules=embed_padding_modules) + elif lora_name == "baichuan7B-zero": + #Test that the target_modules contain prefix + # such as "model.layers.0.self_atten.W_pack", and + # the test should pass. + LoRAModel.from_local_checkpoint( + baichuan_zero_lora_files, + expected_lora_modules, + lora_model_id=1, + device="cpu", + embedding_modules=embedding_modules, + embedding_padding_modules=embed_padding_modules) else: # For the baichuan7B model, load chatglm3-6b's LoRA, # and the test should raise the following error. diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 4d74722aaa926..f6a8a50fa9e50 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -38,8 +38,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): enable_lora=True, max_num_seqs=16, max_loras=4, - tensor_parallel_size=tp_size, - worker_use_ray=True) + tensor_parallel_size=tp_size) expected_lora_output = [ "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501 diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py new file mode 100644 index 0000000000000..a2b42ce4cb96f --- /dev/null +++ b/tests/lora/test_phi.py @@ -0,0 +1,67 @@ +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "microsoft/phi-2" + +PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501 + + +def do_sample(llm, lora_path: str, lora_id: int) -> str: + prompts = [ + PROMPT_TEMPLATE.format( + sql_prompt= + "Which catalog publisher has published the most catalogs?", + context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), + PROMPT_TEMPLATE.format( + sql_prompt= + "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 + context= + "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + sql_prompt= + "How many marine species are found in the Southern Ocean?", # noqa: E501 + context= + "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=64, + stop="### End") + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None, + ) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_phi2_lora(phi2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=2, + enforce_eager=True) + + expected_lora_output = [ + "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 + "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501 + "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 + ] + + output1 = do_sample(llm, phi2_lora_files, lora_id=1) + for i in range(len(expected_lora_output)): + assert output1[i].startswith(expected_lora_output[i]) + output2 = do_sample(llm, phi2_lora_files, lora_id=2) + for i in range(len(expected_lora_output)): + assert output2[i].startswith(expected_lora_output[i]) diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 2736a1c7ade27..f021c003b1322 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -34,26 +34,119 @@ def _lora_ref_impl( for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): xi = x[i].unsqueeze(0).to(torch.float32) wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) - wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + if wb_T_all is not None: + wb = wb_T_all[lora_idx, layer_idx].transpose(-1, + -2).to(torch.float32) tmp = xi @ wa y_stage_1[i] = tmp.squeeze(0) - y_final[i] += (tmp @ wb).squeeze(0) * s + y_final[i] += ((tmp @ wb).squeeze(0) * + s if wb_T_all is not None else y_stage_1[i]) return y_final, y_stage_1 H1 = H2 = [ - 128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456, - 3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216, - 10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512, - 32768, 33024 + 128, + 256, + 512, + 1024, + 1152, + 1280, + 1536, + 2048, + 2304, + 2560, + 2752, + 3072, + 3328, + 3456, + 3584, + 4096, + 4608, + 5120, + 5504, + 5632, + 6144, + 6400, + 6848, + 6912, + 7168, + 8192, + 9216, + 10240, + 11008, + 13824, + 14336, + 15360, + 22016, + 24576, + 27392, + 27648, + 32000, + 32256, + 32512, + 32768, + 33024, + 36864, + 43264, + 49152, + 64000, + 64256, + 102400, + 102656, + 128000, + 128256, ] +H2 = [64] + H2 +R = [1, 2, 4] SEED = [0xabcdabcd987] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("r", R) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_a_extra_shapes(dtype_str, h1, r, seed): + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + bs = 32 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, r, dtype=dtype, device=device) + + y_ref = y.clone() + _lora_ref_impl( + y_ref, + x, + wa_T_all, + None, + indices, + layer_idx, + 1.0, + ) + + y_our = y.clone() + punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0) + + assert_close(y_ref, y_our) + + @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) @pytest.mark.parametrize("h1", H1) @pytest.mark.parametrize("h2", H2) diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py new file mode 100644 index 0000000000000..3d86a4366aa57 --- /dev/null +++ b/tests/lora/test_quant_model.py @@ -0,0 +1,179 @@ +# Adapted from +# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py +from dataclasses import dataclass +from typing import List + +import pytest + +import vllm +from vllm.lora.request import LoRARequest + +from .conftest import cleanup + + +@dataclass +class ModelWithQuantization: + model_path: str + quantization: str + + +MODELS: List[ModelWithQuantization] = [ + ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + quantization="AWQ"), + ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + quantization="GPTQ"), +] + + +def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256): + raw_prompts = [ + "Give me an orange-ish brown color", + "Give me a neon pink color", + ] + + def format_prompt_tuples(prompt): + return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + + prompts = [format_prompt_tuples(p) for p in raw_prompts] + + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=max_tokens, + stop=["<|im_end|>"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tp_size", [1]) +def test_quant_model_lora(tinyllama_lora_files, model, tp_size): + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < tp_size: + # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_model_len=400, + tensor_parallel_size=tp_size, + quantization=model.quantization, + trust_remote_code=True) + + if model.quantization is None: + expected_no_lora_output = [ + "Here are some examples of orange-brown colors", + "I'm sorry, I don't have" + ] + expected_lora_output = [ + "#ff8050", + "#ff8080", + ] + elif model.quantization == "AWQ": + expected_no_lora_output = [ + "I'm sorry, I don't understand", + "I'm sorry, I don't understand", + ] + expected_lora_output = [ + "#f07700: A v", + "#f00000: A v", + ] + elif model.quantization == "GPTQ": + expected_no_lora_output = [ + "I'm sorry, I don't have", + "I'm sorry, I don't have", + ] + expected_lora_output = [ + "#f08800: This is", + "#f07788 \n#", + ] + + def expect_match(output, expected_output): + # HACK: GPTQ lora outputs are just incredibly unstable. + # Assert that the outputs changed. + if (model.quantization == "GPTQ" + and expected_output is expected_lora_output): + assert output != expected_no_lora_output + for i, o in enumerate(output): + assert o.startswith( + '#'), f"Expected example {i} to start with # but got {o}" + return + assert output == expected_output + + max_tokens = 10 + + print("lora adapter created") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=0, + max_tokens=max_tokens) + expect_match(output, expected_no_lora_output) + + print("lora 1") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=1, + max_tokens=max_tokens) + expect_match(output, expected_lora_output) + + print("no lora") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=0, + max_tokens=max_tokens) + expect_match(output, expected_no_lora_output) + + print("lora 2") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=2, + max_tokens=max_tokens) + expect_match(output, expected_lora_output) + + print("removing lora") + + del llm + cleanup() + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.skip("Requires multiple GPUs") +def test_quant_model_tp_equality(tinyllama_lora_files, model): + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < 2: + # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") + + llm_tp1 = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1, + quantization=model.quantization, + trust_remote_code=True) + output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) + + del llm_tp1 + cleanup() + + llm_tp2 = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=2, + quantization=model.quantization) + output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) + + del llm_tp2 + cleanup() + + assert output_tp1 == output_tp2 diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 54594690f7922..732e91a52c0a9 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -3,8 +3,8 @@ import tempfile from unittest.mock import patch -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.worker.worker import Worker @@ -18,12 +18,14 @@ def test_worker_apply_lora(sql_lora_files): "meta-llama/Llama-2-7b-hf", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, ), + load_config=LoadConfig( + download_dir=None, + load_format="dummy", + ), parallel_config=ParallelConfig(1, 1, False), scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 0ab9c63ce4377..e0aa14f165c2d 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -1,4 +1,12 @@ +from typing import List + import pytest +from prometheus_client import REGISTRY + +from vllm import EngineArgs, LLMEngine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams MODELS = [ "facebook/opt-125m", @@ -68,3 +76,119 @@ def test_metric_counter_generation_tokens( assert vllm_generation_count == metric_count, ( f"generation token count: {vllm_generation_count!r}\n" f"metric: {metric_count!r}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize( + "served_model_name", + [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]]) +def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, + served_model_name: List[str]) -> None: + vllm_model = vllm_runner(model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.3, + served_model_name=served_model_name) + stat_logger = vllm_model.model.llm_engine.stat_logger + metrics_tag_content = stat_logger.labels["model_name"] + + del vllm_model + + if served_model_name is None or served_model_name == []: + assert metrics_tag_content == model, ( + f"Metrics tag model_name is wrong! expect: {model!r}\n" + f"actual: {metrics_tag_content!r}") + else: + assert metrics_tag_content == served_model_name[0], ( + f"Metrics tag model_name is wrong! expect: " + f"{served_model_name[0]!r}\n" + f"actual: {metrics_tag_content!r}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("disable_log_stats", [True, False]) +@pytest.mark.asyncio +async def test_async_engine_log_metrics_regression( + example_prompts, + model: str, + dtype: str, + max_tokens: int, + disable_log_stats: bool, +) -> None: + """ + Regression test ensuring async engine generates metrics + when disable_log_stats=False + (see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678) + """ + engine_args = AsyncEngineArgs(model=model, + dtype=dtype, + disable_log_stats=disable_log_stats) + async_engine = AsyncLLMEngine.from_engine_args(engine_args) + for i, prompt in enumerate(example_prompts): + results = async_engine.generate( + prompt, + SamplingParams(max_tokens=max_tokens), + f"request-id-{i}", + ) + # Exhaust the async iterator to make the async engine work + async for _ in results: + pass + + assert_metrics(async_engine.engine, disable_log_stats, + len(example_prompts)) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("disable_log_stats", [True, False]) +def test_engine_log_metrics_regression( + example_prompts, + model: str, + dtype: str, + max_tokens: int, + disable_log_stats: bool, +) -> None: + engine_args = EngineArgs(model=model, + dtype=dtype, + disable_log_stats=disable_log_stats) + engine = LLMEngine.from_engine_args(engine_args) + for i, prompt in enumerate(example_prompts): + engine.add_request( + f"request-id-{i}", + prompt, + SamplingParams(max_tokens=max_tokens), + ) + while engine.has_unfinished_requests(): + engine.step() + + assert_metrics(engine, disable_log_stats, len(example_prompts)) + + +def assert_metrics(engine: LLMEngine, disable_log_stats: bool, + num_requests: int) -> None: + if disable_log_stats: + with pytest.raises(AttributeError): + _ = engine.stat_logger + else: + assert (engine.stat_logger + is not None), "engine.stat_logger should be set" + # Ensure the count bucket of request-level histogram metrics matches + # the number of requests as a simple sanity check to ensure metrics are + # generated + labels = {'model_name': engine.model_config.model} + request_histogram_metrics = [ + "vllm:e2e_request_latency_seconds", + "vllm:request_prompt_tokens", + "vllm:request_generation_tokens", + "vllm:request_params_best_of", + "vllm:request_params_n", + ] + for metric_name in request_histogram_metrics: + metric_value = REGISTRY.get_sample_value(f"{metric_name}_count", + labels) + assert ( + metric_value == num_requests), "Metrics should be collected" diff --git a/tests/model_executor/__init__.py b/tests/model_executor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py index 3154f2826d10c..c8b9bed691bba 100644 --- a/tests/model_executor/weight_utils.py +++ b/tests/model_executor/weight_utils.py @@ -1,9 +1,12 @@ import os +import tempfile import huggingface_hub.constants import pytest +from huggingface_hub.utils import LocalEntryNotFoundError -from vllm.model_executor.weight_utils import enable_hf_transfer +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf, enable_hf_transfer) def test_hf_transfer_auto_activation(): @@ -22,5 +25,30 @@ def test_hf_transfer_auto_activation(): HF_TRANFER_ACTIVE) +def test_download_weights_from_hf(): + with tempfile.TemporaryDirectory() as tmpdir: + # assert LocalEntryNotFoundError error is thrown + # if offline is set and model is not cached + huggingface_hub.constants.HF_HUB_OFFLINE = True + with pytest.raises(LocalEntryNotFoundError): + download_weights_from_hf("facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir) + + # download the model + huggingface_hub.constants.HF_HUB_OFFLINE = False + download_weights_from_hf("facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir) + + # now it should work offline + huggingface_hub.constants.HF_HUB_OFFLINE = True + assert download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir) is not None + + if __name__ == "__main__": test_hf_transfer_auto_activation() + test_download_weights_from_hf() diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_aqlm.py b/tests/models/test_aqlm.py new file mode 100644 index 0000000000000..a7abc011f57d7 --- /dev/null +++ b/tests/models/test_aqlm.py @@ -0,0 +1,95 @@ +"""Compare the outputs of a AQLM model between vLLM and HF Transformers + +Run `pytest tests/models/test_aqlm.py`. +""" + +import pytest +import torch + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +aqlm_not_supported = (capability < + QUANTIZATION_METHODS["aqlm"].get_min_capability()) + +# In this test we hardcode prompts and generations for the model so we don't +# need to require the AQLM package as a dependency +example_prompts = [ + 'vLLM is a high-throughput and memory-efficient inference and serving ' + 'engine for LLMs.\n', + 'Briefly describe the major milestones in the development of artificial ' + 'intelligence from 1950 to 2020.\n', + 'Compare and contrast artificial intelligence with human intelligence in ' + 'terms of processing information.\n', + 'Describe the basic components of a neural network and how it can be ' + 'trained.\n', + 'Write a short story about a robot that dreams for the first time.\n', + 'Analyze the impact of the COVID-19 pandemic on global economic structures ' + 'and future business models.\n', + 'Explain the cultural significance of the Mona Lisa painting, and how its ' + 'perception might vary in Western versus Eastern societies.\n', + "Translate the following English sentence into Japanese, French, and " + "Swahili: 'The early bird catches the worm.'\n" +] + +# These ground truth generations were generated using `transformers==4.38.1 +# aqlm==1.1.0 torch==2.2.0` +# and the below code: +# ```python +# from transformers import AutoTokenizer, AutoModelForCausalLM +# model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf" +# quantized_model = AutoModelForCausalLM.from_pretrained(model_id, +# torch_dtype="auto", device_map="cuda").cuda() +# tokenizer = AutoTokenizer.from_pretrained(model_id) +# outputs = [] +# for prompt in example_prompts: +# input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda") +# hf_outputs = quantized_model.generate(input_ids, max_new_tokens=32) +# outputs.append(tokenizer.decode(hf_outputs[0][input_ids.shape[1]:])) +# print(outputs) +# ``` +ground_truth_generations = [ + '\n### Features\n\n- **High-throughput**: v', + 'The major milestones in the development of artificial intelligence from ' + '195', + 'Compare and contrast artificial intelligence with human intelligence in ' + 'terms of processing information. The', + 'Explain the difference between supervised and unsupervised learning.' + '\nExplain', + 'Write a short story about a robot that dreams for the first time. The', + 'Analyze the impact of the COVID-19 pandemic on global economic', + 'The Mona Lisa is a painting by Leonardo da Vinci, and it', + 'The early bird catches the worm.\nThe early bird catches the' +] + + +@pytest.mark.skipif(aqlm_not_supported, + reason="AQLM is not supported on this GPU type.") +@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"]) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("num_logprobs", [1]) +def test_models( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) + + # loop through the prompts to compare against the ground truth generations + for prompt_idx in range(len(example_prompts)): + vllm_output_ids, vllm_output_str, vllm_logprobs = vllm_outputs[ + prompt_idx] + + print("Prompt: ", repr(example_prompts[prompt_idx])) + print("Reference output:", repr(ground_truth_generations[prompt_idx])) + print("Output output: ", repr(vllm_output_str)) + assert vllm_output_str == ground_truth_generations[prompt_idx] diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 504eaad43c8d7..10e7c64e34e75 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -8,11 +8,11 @@ MODELS = [ "meta-llama/Llama-2-7b-hf", - # "mistralai/Mistral-7B-v0.1", # Broken + # "mistralai/Mistral-7B-v0.1", # Tested by test_mistral.py # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", - "mosaicml/mpt-7b", + # "mosaicml/mpt-7b", # Broken # "Qwen/Qwen1.5-0.5B" # Broken, ] @@ -43,3 +43,18 @@ def test_models( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + del vllm_model diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py new file mode 100644 index 0000000000000..59bf054913f7c --- /dev/null +++ b/tests/models/test_embedding.py @@ -0,0 +1,44 @@ +"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. + +Run `pytest tests/models/test_llama_embedding.py`. +""" +import pytest +import torch +import torch.nn.functional as F + +MODELS = [ + "intfloat/e5-mistral-7b-instruct", +] + + +def compare_embeddings(embeddings1, embeddings2): + similarities = [ + F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0) + for e1, e2 in zip(embeddings1, embeddings2) + ] + return similarities + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.encode(example_prompts) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_outputs = vllm_model.encode(example_prompts) + del vllm_model + + similarities = compare_embeddings(hf_outputs, vllm_outputs) + all_similarities = torch.stack(similarities) + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py new file mode 100644 index 0000000000000..0a5819ea3f054 --- /dev/null +++ b/tests/models/test_fp8.py @@ -0,0 +1,114 @@ +# flake8: noqa +"""Tests fp8 models against ground truth generation +Note: these tests will only pass on L4 GPU. +""" +import os + +import pytest +import torch +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +MODELS = [ + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", + "meta-llama/Meta-Llama-3-8B-Instruct", +] + +EXPECTED_STRS_MAP = { + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": { + "auto": [ + 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no' + ], + "fp8": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system made up of several basic components that work together to enable it to', + 'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk' + ] + }, + "meta-llama/Meta-Llama-3-8B-Instruct": { + "auto": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' + ], + "fp8": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + 'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu' + ] + }, +} + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +fp8_not_supported = (capability < + QUANTIZATION_METHODS["fp8"].get_min_capability()) + + +@pytest.mark.skipif(fp8_not_supported, + reason="fp8 is not supported on this GPU type.") +@pytest.mark.parametrize("model_name", MODELS) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +def test_models(example_prompts, model_name, kv_cache_dtype) -> None: + model = LLM(model=model_name, + max_model_len=MAX_MODEL_LEN, + trust_remote_code=True, + enforce_eager=True, + quantization="fp8", + kv_cache_dtype=kv_cache_dtype) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + formatted_prompts = [ + tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + tokenize=False, + add_generation_prompt=True) + for prompt in example_prompts + ] + + params = SamplingParams(max_tokens=20, temperature=0) + generations = [] + # Note: these need to be run 1 at a time due to numerical precision, + # since the expected strs were generated this way. + for prompt in formatted_prompts: + outputs = model.generate(prompt, params) + generations.append(outputs[0].outputs[0].text) + del model + + print(model_name, kv_cache_dtype, generations) + expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype] + for i in range(len(example_prompts)): + generated_str = generations[i] + expected_str = expected_strs[i] + assert expected_str == generated_str, ( + f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py new file mode 100644 index 0000000000000..1fc0b3f239127 --- /dev/null +++ b/tests/models/test_gptq_marlin.py @@ -0,0 +1,102 @@ +"""Compares the outputs of gptq vs gptq_marlin +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 5 selections of each other. +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_gptq_marlin.py`. +""" +import os + +import pytest +import torch + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT + +from .utils import check_logprobs_close + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +gptq_marlin_not_supported = ( + capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability()) + +MODELS = [ + # act_order==False, group_size=channelwise + ("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"), + # act_order==False, group_size=128 + ("TheBloke/Llama-2-7B-GPTQ", "main"), + + # act_order==True, group_size=128 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"), + # act_order==True, group_size=64 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"), + # act_order==True, group_size=32 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"), + + # 8-bit, act_order==True, group_size=channelwise + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"), + # 8-bit, act_order==True, group_size=128 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True"), + # 8-bit, act_order==True, group_size=32 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True"), +] + + +@pytest.mark.flaky(reruns=3) +@pytest.mark.skipif(gptq_marlin_not_supported, + reason="gptq_marlin is not supported on this GPU type.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + model_name, revision = model + + # Run marlin. + gptq_marlin_model = vllm_runner(model_name=model_name, + revision=revision, + dtype=dtype, + quantization="marlin", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1) + + gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( + example_prompts[:-1], max_tokens, num_logprobs) + del gptq_marlin_model + _ROPE_DICT.clear() # clear rope cache to avoid rope dtype error + + # Run gptq. + # The naive gptq kernel doesn't support bf16 yet. + # Here we always compare fp16/bf16 gpt marlin kernel + # to fp16 gptq kernel. + gptq_model = vllm_runner(model_name=model_name, + revision=revision, + dtype="half", + quantization="gptq", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1) + gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts[:-1], + max_tokens, + num_logprobs) + del gptq_model + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=gptq_marlin_outputs, + name_0="gptq", + name_1="gptq_marlin", + ) diff --git a/tests/models/test_gptq_marlin_24.py b/tests/models/test_gptq_marlin_24.py new file mode 100644 index 0000000000000..3e6ffb7f90fcc --- /dev/null +++ b/tests/models/test_gptq_marlin_24.py @@ -0,0 +1,81 @@ +"""Compare the outputs of a GPTQ model to a Marlin_24 model. + +Note: GPTQ and Marlin_24 do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Run `pytest tests/models/test_marlin_24.py`. +""" +from dataclasses import dataclass + +import pytest +import torch + +from tests.models.utils import check_logprobs_close +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +marlin_not_supported = (capability < + QUANTIZATION_METHODS["marlin"].get_min_capability()) + + +@dataclass +class ModelPair: + model_marlin: str + model_gptq: str + + +model_pairs = [ + # 4-bit, group_size == 128 + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128", + model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128"), + # 4-bit, group_size == channelwise + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-channelwise", + model_gptq="alexm-nm/tinyllama-24-gptq-4bit-channelwise"), + + # 8-bit, group_size == 128 + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128", + model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128"), + # 8-bit, group_size == channelwise + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-channelwise", + model_gptq="alexm-nm/tinyllama-24-gptq-8bit-channelwise"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(marlin_not_supported, + reason="Marlin24 is not supported on this GPU type.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [8]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + marlin_24_model = vllm_runner(model_pair.model_marlin, + dtype=dtype, + quantization="gptq_marlin_24") + marlin_24_outputs = marlin_24_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + del marlin_24_model + + gptq_model = vllm_runner(model_pair.model_gptq, + dtype=dtype, + quantization="gptq") + gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) + del gptq_model + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=marlin_24_outputs, + name_0="gptq", + name_1="marlin_24", + ) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 2305db3510060..37c1664afec55 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -10,19 +10,19 @@ Run `pytest tests/models/test_marlin.py`. """ - from dataclasses import dataclass import pytest import torch -from vllm.model_executor.layers.quantization import ( - _QUANTIZATION_CONFIG_REGISTRY) +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +from .utils import check_logprobs_close capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] -marlin_not_supported = ( - capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability()) +marlin_not_supported = (capability < + QUANTIZATION_METHODS["marlin"].get_min_capability()) @dataclass @@ -56,43 +56,24 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype) + marlin_model = vllm_runner(model_pair.model_marlin, + dtype=dtype, + quantization="marlin") marlin_outputs = marlin_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - - # Note: not sure why, but deleting just the model on Ada Lovelace - # does not free the GPU memory. On Ampere, deleting the just model - # frees the memory. del marlin_model - gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) + gptq_model = vllm_runner(model_pair.model_gptq, + dtype=dtype, + quantization="gptq") gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs) - - # Note: not sure why, but deleting just the model on Ada Lovelace - # does not free the GPU memory. On Ampere, deleting the just model - # frees the memory. del gptq_model - # loop through the prompts - for prompt_idx in range(len(example_prompts)): - gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[ - prompt_idx] - marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[ - prompt_idx] - - for idx, (gptq_output_id, marlin_output_id) in enumerate( - zip(gptq_output_ids, marlin_output_ids)): - # If sequence is not an exact match, - if marlin_output_id != gptq_output_id: - # Each predicted token must be in top 5 of the other's - assert gptq_output_id in marlin_logprobs[idx], ( - f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n" - f"Marlin:\t{marlin_output_str!r}") - assert marlin_output_id in gptq_logprobs[idx], ( - f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n" - f"Marlin:\t{marlin_output_str!r}") - - # Break out since sequences will now diverge. - break + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=marlin_outputs, + name_0="gptq", + name_1="marlin", + ) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 7aeff3a913098..76b248cf14e98 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -4,37 +4,41 @@ """ import pytest +from .utils import check_logprobs_close + MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.3", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.skip( - "Two problems: 1. Failing correctness tests. 2. RuntimeError: expected " - "scalar type BFloat16 but found Half (only in CI).") +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_models( hf_runner, vllm_runner, - example_long_prompts, + example_prompts, model: str, dtype: str, max_tokens: int, + num_logprobs: int, ) -> None: + # TODO(sang): Sliding window should be tested separately. hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens) + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) del hf_model vllm_model = vllm_runner(model, dtype=dtype) - vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens) + vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) del vllm_model - - for i in range(len(example_long_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 53a80d4619646..e4609620387fa 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -12,7 +12,7 @@ "gpt2", "bigcode/tiny_starcoder_py", "EleutherAI/pythia-70m", - "bigscience/bloom-560m", + "bigscience/bloom-560m", # Testing alibi slopes. "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t", # "allenai/OLMo-1B", # Broken @@ -49,3 +49,18 @@ def test_models( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + del vllm_model diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py new file mode 100644 index 0000000000000..547ab10051f1b --- /dev/null +++ b/tests/models/test_registry.py @@ -0,0 +1,9 @@ +import pytest + +from vllm.model_executor.models import _MODELS, ModelRegistry + + +@pytest.mark.parametrize("model_cls", _MODELS) +def test_registry_imports(model_cls): + # Ensure all model classes can be imported successfully + ModelRegistry.load_model_cls(model_cls) diff --git a/tests/models/utils.py b/tests/models/utils.py new file mode 100644 index 0000000000000..3e49dfb331176 --- /dev/null +++ b/tests/models/utils.py @@ -0,0 +1,29 @@ +def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1): + """Compare the logprobs of two sequences generated by different models, + which should be similar but not necessarily equal. + """ + # Loop through responses to each prompt. + for prompt_idx, (outputs_0, + outputs_1) in enumerate(zip(outputs_0_lst, + outputs_1_lst)): + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + + # Loop through generated tokens. + for idx, (output_id_0, + output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): + + # If generated tokens don't match, then + if output_id_0 != output_id_1: + # Each predicted token must be in top N logprobs of the other + assert output_id_0 in logprobs_1[idx], ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + assert output_id_1 in logprobs_0[idx], ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + + # Break out since sequences will now diverge. + break diff --git a/tests/prefix_caching/__init__.py b/tests/prefix_caching/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/prefix_caching/test_disable_sliding_window.py b/tests/prefix_caching/test_disable_sliding_window.py new file mode 100644 index 0000000000000..eeac6ab43c05f --- /dev/null +++ b/tests/prefix_caching/test_disable_sliding_window.py @@ -0,0 +1,44 @@ +"""Compare the with and without prefix caching. + +Run `pytest tests/prefix_caching/test_prefix_caching.py`. +""" +import pytest + +from tests.conftest import cleanup +from vllm import LLM + +MODEL_LEN_LEN = [ + # Example models with sliding window. + ("bigcode/starcoder2-3b", 4096, 16384), + # ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI + + # Confirm model with sliding window works. + # config has "use_sliding_window": false + ("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768), + # config has no sliding window attribute. + ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048), +] + + +@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN) +def test_disable_sliding_window(model_len_len, ): + model, sliding_len, full_len = model_len_len + vllm_disabled_model = LLM(model, disable_sliding_window=True) + vllm_disabled_model.generate("Hi my name is") + model_config = vllm_disabled_model.llm_engine.model_config + assert model_config.max_model_len == sliding_len, ( + "Max len expected to equal sliding_len of %s, but got %s", sliding_len, + model_config.max_model_len) + + del vllm_disabled_model + cleanup() + + vllm_enabled_model = LLM(model, disable_sliding_window=False) + vllm_enabled_model.generate("Hi my name is") + model_config = vllm_enabled_model.llm_engine.model_config + assert model_config.max_model_len == full_len, ( + "Max len expected to equal full_len of %s, but got %s", full_len, + model_config.max_model_len) + + del vllm_enabled_model + cleanup() diff --git a/tests/quantization/__init__.py b/tests/quantization/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/quantization/test_autogptq_marlin_configs.py b/tests/quantization/test_autogptq_marlin_configs.py deleted file mode 100644 index cd64622e2226f..0000000000000 --- a/tests/quantization/test_autogptq_marlin_configs.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Tests whether Marlin models can be loaded from the autogptq config. - -Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`. -""" - -from dataclasses import dataclass - -import pytest - -from vllm.config import ModelConfig - - -@dataclass -class ModelPair: - model_marlin: str - model_gptq: str - - -# Model Id // Expected Kernel -MODELS_QUANT_TYPE = [ - # compat: autogptq <=0.7.1 is_marlin_format: bool - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"), - ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"), - # compat: autogptq >=0.8.0 use checkpoint_format: str - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq") -] - - -@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE) -def test_auto_gptq(model_quant_type: str, ) -> None: - model_path, quant_type = model_quant_type - - model_config_no_quant_arg = ModelConfig( - model_path, - model_path, - tokenizer_mode="auto", - trust_remote_code=False, - download_dir=None, - load_format="dummy", - seed=0, - dtype="float16", - revision=None, - quantization=None # case 1 - ) - - model_config_quant_arg = ModelConfig( - model_path, - model_path, - tokenizer_mode="auto", - trust_remote_code=False, - download_dir=None, - load_format="dummy", - seed=0, - dtype="float16", - revision=None, - quantization="gptq" # case 2 - ) - - assert model_config_no_quant_arg.quantization == quant_type, ( - f"Expected quant_type == {quant_type} for {model_path}, " - f"but found {model_config_no_quant_arg.quantization} " - "for no --quantization None case") - - assert model_config_quant_arg.quantization == quant_type, ( - f"Expected quant_type == {quant_type} for {model_path}, " - f"but found {model_config_quant_arg.quantization} " - "for --quantization gptq case") diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py new file mode 100644 index 0000000000000..b83286992da3d --- /dev/null +++ b/tests/quantization/test_compressed_tensors.py @@ -0,0 +1,36 @@ +"""Test model set-up and weight loading for sparseml-quantized models. + +Run `pytest tests/quantization/test_compressed_tensors.py`. +""" + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor) + + +def test_compressed_tensors_w8a8_static_setup(vllm_runner): + model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed" + llm = vllm_runner(model_path, quantization="sparseml", enforce_eager=True) + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) + + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor) + + assert qkv_proj.weight.dtype is torch.int8 + assert o_proj.weight.dtype is torch.int8 + assert gate_up_proj.weight.dtype is torch.int8 + + assert qkv_proj.weight_scale.shard_splitter is not None + assert qkv_proj.weight_scale.logical_widths is not None + assert qkv_proj.input_scale.dtype is torch.float32 diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py new file mode 100644 index 0000000000000..6820b2728e3c9 --- /dev/null +++ b/tests/quantization/test_configs.py @@ -0,0 +1,73 @@ +"""Tests whether Marlin models can be loaded from the autogptq config. + +Run `pytest tests/quantization/test_configs.py --forked`. +""" + +from dataclasses import dataclass + +import pytest + +from vllm.config import ModelConfig + + +@dataclass +class ModelPair: + model_marlin: str + model_gptq: str + + +# Model Id // Quantization Arg // Expected Type +MODEL_ARG_EXPTYPES = [ + # AUTOGPTQ + # compat: autogptq <=0.7.1 is_marlin_format: bool + # Model Serialized in Marlin Format should always use Marlin kernel. + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"), + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"), + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"), + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"), + # Model Serialized in Exllama Format. + ("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"), + # compat: autogptq >=0.8.0 use checkpoint_format: str + # Model Serialized in Marlin Format should always use Marlin kernel. + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"), + # Model Serialized in Exllama Format. + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"), + + # AUTOAWQ + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "ERROR"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"), +] + + +@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES) +def test_auto_gptq(model_arg_exptype: str) -> None: + model_path, quantization_arg, expected_type = model_arg_exptype + + try: + model_config = ModelConfig(model_path, + model_path, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None, + quantization=quantization_arg) + found_quantization_type = model_config.quantization + except ValueError: + found_quantization_type = "ERROR" + + assert found_quantization_type == expected_type, ( + f"Expected quant_type == {expected_type} for {model_path}, " + f"but found {found_quantization_type} " + f"for no --quantization {quantization_arg} case") diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py new file mode 100644 index 0000000000000..607544a1c8394 --- /dev/null +++ b/tests/quantization/test_fp8.py @@ -0,0 +1,24 @@ +"""Tests whether FP8 computation is enabled correctly. + +Run `pytest tests/quantization/test_fp8.py --forked`. +""" +import pytest +import torch + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] + + +@pytest.mark.skipif( + capability < QUANTIZATION_METHODS["fp8"].get_min_capability(), + reason="FP8 is not supported on this GPU type.") +def test_load_fp16_model(vllm_runner) -> None: + llm = vllm_runner("facebook/opt-125m", quantization="fp8") + + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model + fc1 = model.model.decoder.layers[0].fc1 + assert isinstance(fc1.quant_method, Fp8LinearMethod) + assert fc1.weight.dtype == torch.float8_e4m3fn diff --git a/tests/samplers/__init__.py b/tests/samplers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py new file mode 100644 index 0000000000000..864657a3c2b28 --- /dev/null +++ b/tests/samplers/test_ignore_eos.py @@ -0,0 +1,31 @@ +"""Make sure ignore_eos works. + +Run `pytest tests/samplers/test_ignore_eos.py`. +""" + +import pytest + +from vllm import SamplingParams + +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [1024]) +def test_beam_search_single_input( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + example_prompts = "1 + 1 is" + + vllm_model = vllm_runner(model, dtype=dtype) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) + ignore_eos_output = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params) + print(len(ignore_eos_output[0].outputs[0].token_ids)) + assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) < 10 + assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) >= 0 diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py new file mode 100644 index 0000000000000..0ccbabfff6403 --- /dev/null +++ b/tests/samplers/test_logits_processor.py @@ -0,0 +1,59 @@ +import pytest +import torch + +from vllm import SamplingParams + +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_logits_processor_force_generate( + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + tokenizer = vllm_model.model.get_tokenizer() + repeat_times = 2 + enforced_answers = " vLLM" + vllm_token_ids = tokenizer.encode(enforced_answers, + add_special_tokens=False) + max_tokens = len(vllm_token_ids) * repeat_times + + def pick_vllm(token_ids, logits): + token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)] + logits[token_id] = torch.finfo(logits.dtype).max + return logits + + params_with_logprobs = SamplingParams( + logits_processors=[pick_vllm], + prompt_logprobs=3, + max_tokens=max_tokens, + ) + + # test logits_processors when prompt_logprobs is not None + vllm_model.model._add_request( + example_prompts[0], + params=params_with_logprobs, + ) + + # test prompt_logprobs is not None + vllm_model.model._add_request( + example_prompts[1], + params=SamplingParams( + prompt_logprobs=3, + max_tokens=max_tokens, + ), + ) + + # test grouped requests + vllm_model.model._add_request( + example_prompts[2], + params=SamplingParams(max_tokens=max_tokens), + ) + + outputs = vllm_model.model._run_engine(use_tqdm=False) + + assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 41b7f3da1e839..40d054cd472b8 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -1,23 +1,35 @@ import pytest import torch -from tests.conftest import VllmRunner from vllm import SamplingParams +from ..conftest import VllmRunner + MODELS = ["facebook/opt-125m"] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) +@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size def test_get_prompt_logprobs( hf_runner, vllm_runner, model, dtype, + chunked_prefill_token_size: int, + num_top_logprobs: int, example_prompts, ): + max_num_seqs = 256 + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) + max_num_batched_tokens = chunked_prefill_token_size + max_tokens = 5 - num_top_logprobs = 6 hf_model = hf_runner(model, dtype=dtype) hf_logprobs = hf_model.generate_greedy_logprobs( example_prompts, @@ -25,10 +37,17 @@ def test_get_prompt_logprobs( ) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) + vllm_model = vllm_runner( + model, + dtype=dtype, + max_logprobs=num_top_logprobs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs, + ) vllm_sampling_params = SamplingParams(max_tokens=max_tokens, logprobs=num_top_logprobs, - prompt_logprobs=5, + prompt_logprobs=num_top_logprobs, temperature=0.0) vllm_results = vllm_model.model.generate( example_prompts, sampling_params=vllm_sampling_params) @@ -52,9 +71,18 @@ def test_get_prompt_logprobs( "The output text from the top logprob for each token position " "should be the same as the output text in the result.") + # The first prompt logprob is always None + assert result.prompt_logprobs[0] is None + for prompt_logprobs in result.prompt_logprobs[1:]: + # If the prompt token is not included in the top X + # logprob, it can return 1 more data + assert (len(prompt_logprobs) == num_top_logprobs + or len(prompt_logprobs) == num_top_logprobs + 1) + # Test whether prompt logprobs are consistent with HF for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs): # Check prompt logprobs + # The first prompt logprob is always None, so we compare it from 1:. vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): for token_id, logprob in vllm_prompt_logprob_dict.items(): @@ -74,6 +102,17 @@ def test_get_prompt_logprobs( "The token should be decoded by the time it is returned " " to the user.") + # Test if prompt logprobs are correctly set. + for vllm_result in vllm_results: + token_ids = vllm_result.prompt_token_ids + prompt_logprobs = vllm_result.prompt_logprobs + + # The first token doesn't have logprob. + assert prompt_logprobs[0] is None + + for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]): + assert token_id in logprob_dict + def test_max_logprobs(): runner = VllmRunner("facebook/opt-125m", max_logprobs=1) diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index d2c3a798d3087..00a2379502e6d 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -42,9 +42,11 @@ def mock_causal_accepted_tensor( @pytest.mark.parametrize( "which_tokens_accepted", ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_correct_output_format(which_tokens_accepted: str, seed: int, +def test_correct_output_format(which_tokens_accepted: str, + disable_bonus_tokens: bool, seed: int, device: str): """Verify the output has correct format given predetermined accepted matrix. """ @@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, size=(batch_size, 1), dtype=torch.int64) - rejection_sampler = RejectionSampler() + rejection_sampler = RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens) rejection_sampler.init_gpu_tensors(rank=0) output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access accepted, @@ -91,12 +94,18 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, bonus_token_ids, ) + expected_bonus_token_ids = bonus_token_ids.clone() + # If bonus tokens disabled. Verify they are set to -1. + # See https://github.com/vllm-project/vllm/issues/4212 + if disable_bonus_tokens: + expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1 + if which_tokens_accepted == "all_tokens_accepted": # Expect all tokens to be equal to draft tokens. assert torch.equal(output_token_ids[:, :-1], draft_token_ids) # Expect all bonus tokens to be included. - assert torch.equal(output_token_ids[:, -1:], bonus_token_ids) + assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids) elif which_tokens_accepted == "no_tokens_accepted": # Expect first token to be equal to recovered tokens. assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0]) @@ -106,7 +115,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, torch.ones_like(output_token_ids[:, 1:]) * -1) elif which_tokens_accepted == "some_tokens_accepted": recovered_plus_bonus = torch.cat( - (recovered_token_ids, bonus_token_ids), dim=-1) + (recovered_token_ids, expected_bonus_token_ids), dim=-1) # Assert first rejected token is a recovered token or bonus token. assert torch.equal( recovered_plus_bonus[torch.arange(0, batch_size), diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 1626b72282072..ddc66aa28a094 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,3 +1,4 @@ +import itertools import random from typing import List, Optional, Tuple from unittest.mock import patch @@ -7,10 +8,10 @@ from transformers import GenerationConfig, GenerationMixin from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import Counter -from vllm.worker.model_runner import ModelRunner +from vllm.utils import Counter, is_pin_memory_available class MockLogitsSampler(Sampler): @@ -24,15 +25,14 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, VOCAB_SIZE), 1e-2, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(None, None, None, None, None) - return input_tensor, fake_logits, sampler, model_runner + return input_tensor, fake_logits, sampler VOCAB_SIZE = 32000 @@ -46,11 +46,11 @@ def _do_sample( batch_size: int, input_tensor: torch.Tensor, sampler: MockLogitsSampler, - model_runner: ModelRunner, sampling_params: SamplingParams, + device: str, ): seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -60,11 +60,14 @@ def _do_sample( sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=device, + pin_memory=is_pin_memory_available()) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -74,19 +77,16 @@ def test_sampler_all_greedy(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams(temperature=0) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, - sampling_params) + sampler_output = _do_sample(batch_size, fake_logits, sampler, + sampling_params, device) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == expected[i].item() - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -94,8 +94,7 @@ def test_sampler_all_random(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -104,15 +103,13 @@ def test_sampler_all_random(seed: int, device: str): temperature=1.0, n=random.randint(1, 10), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, - sampling_params) + sampler_output = _do_sample(batch_size, fake_logits, sampler, + sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -120,7 +117,7 @@ def test_sampler_all_random_seed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -130,15 +127,13 @@ def test_sampler_all_random_seed(seed: int, device: str): n=random.randint(1, 10), seed=random.randint(0, 10000), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, - sampling_params) + sampler_output = _do_sample(batch_size, fake_logits, sampler, + sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -146,7 +141,7 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=1.0, @@ -154,15 +149,13 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): seed=random.randint(0, 10000), ) first_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params) + sampling_params, device) second_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params) + sampling_params, device) assert first_sampler_output == second_sampler_output - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -170,19 +163,18 @@ def test_sampler_all_beam(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=0, best_of=2, use_beam_search=True, ) - _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params) + _do_sample(batch_size, fake_logits, sampler, sampling_params, device) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler # when handling an all-beam search case. - del model_runner @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -194,13 +186,17 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): def create_sampling_params(min_tokens, eos_token_id=0, - stop_token_ids=None): + *, + stop_token_ids: Optional[List[int]] = None, + prompt_logprobs: Optional[int] = None): sampling_params = SamplingParams( min_tokens=min_tokens, max_tokens=9999, # keep higher than max of min_tokens stop_token_ids=stop_token_ids, + # requesting prompt_logprobs changes the structure of `logits` + prompt_logprobs=prompt_logprobs, ) - sampling_params.eos_token_id = eos_token_id + sampling_params.all_stop_token_ids.add(eos_token_id) return sampling_params def create_sequence_data(num_input=3, num_generated=0): @@ -217,9 +213,9 @@ def generate_test_case(): expected_penalization = [] sequence_metadata_list = [] + # 20% chance to generate seq group metadata list with all prompts + is_prompt = random.random() < 0.2 while batch_size > 0: - # 20% chance to generate prompt seq group with single sequence - is_prompt = random.random() < 0.2 num_seqs = 1 if is_prompt else random.randint(1, batch_size) eos_token_id = random.randint(0, VOCAB_SIZE - 1) @@ -240,7 +236,7 @@ def generate_test_case(): seq_group_penalization = [] for _ in range(num_seqs): num_input = random.randint(1, 100) - num_generated = random.randint(1, 100) if not is_prompt else 0 + num_generated = 0 if is_prompt else random.randint(1, 100) seq_data[next(seq_id_counter)] = create_sequence_data( num_input=num_input, num_generated=num_generated) seq_group_penalization.append(num_generated < min_tokens) @@ -292,6 +288,21 @@ def generate_test_case(): ] } + prompt_with_penalization_and_prompt_logprobs = { + "expected_penalization": [False, False, True], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_1", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(num_input=3), + }, + sampling_params=create_sampling_params(1, prompt_logprobs=3), + block_tables={}, + ), + ] + } + stop_penalizing_after_min_tokens = { "expected_penalization": [False], "seq_group_metadata_list": [ @@ -309,8 +320,34 @@ def generate_test_case(): } stop_token_ids = [42, 99, 42, 0] # intentional duplication - simple_combination = { - "expected_penalization": [True, False, False], + prompt_combination = { + "expected_penalization": [False, True, False], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_2", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(num_input=2), + }, + sampling_params=create_sampling_params(1, prompt_logprobs=3), + block_tables={}, + ), + SequenceGroupMetadata( + request_id="test_3", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(), + }, + sampling_params=create_sampling_params( + 0, stop_token_ids=stop_token_ids), + block_tables={}, + ) + ] + } + + stop_token_ids = [1, 999, 37, 37] # intentional duplication + decode_combination = { + "expected_penalization": [True, False, False, True, False], "seq_group_metadata_list": [ SequenceGroupMetadata( request_id="test_1", @@ -327,14 +364,19 @@ def generate_test_case(): ), SequenceGroupMetadata( request_id="test_2", - is_prompt=True, + is_prompt=False, seq_data={ - next(seq_id_counter): create_sequence_data(), + next(seq_id_counter): + create_sequence_data(num_generated=20), + next(seq_id_counter): + create_sequence_data(num_generated=1), + next(seq_id_counter): + create_sequence_data(num_generated=10), }, sampling_params=create_sampling_params( - 0, stop_token_ids=stop_token_ids), + 10, prompt_logprobs=5, stop_token_ids=stop_token_ids), block_tables={}, - ) + ), ] } @@ -342,8 +384,10 @@ def generate_test_case(): test_cases = [ prompt_without_penalization, prompt_with_penalization, + prompt_with_penalization_and_prompt_logprobs, stop_penalizing_after_min_tokens, - simple_combination, + prompt_combination, + decode_combination, ] else: test_cases = [generate_test_case()] @@ -351,35 +395,53 @@ def generate_test_case(): def run_test_case(*, expected_penalization=None, seq_group_metadata_list=None): - assert expected_penalization, "Invalid test case" - assert seq_group_metadata_list, "Invalid test case" + assert expected_penalization, \ + "Invalid test case, need expected_penalization" + assert seq_group_metadata_list, \ + "Invalid test case, need seq_group_metadata_list" batch_size = 0 - prompt_lens = [] - sampling_params_per_seq = [] + seq_lens = [] + sampling_params_per_row = [] for sgm in seq_group_metadata_list: - num_seqs = len(sgm.seq_data) - batch_size += num_seqs sampling_params = sgm.sampling_params - for seq_id in sgm.seq_data: - prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len()) - sampling_params_per_seq.append(sampling_params) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) - sampling_metadata = model_runner._prepare_sample( + num_rows = len(sgm.seq_data) + if sgm.is_prompt: + # a prompt seq_group has only one sequence + seq_data = next(iter(sgm.seq_data.values())) + prompt_len = seq_data.get_prompt_len() + seq_lens.append(prompt_len) + + if sgm.sampling_params.prompt_logprobs: + # with prompt_logprobs each token in the prompt has a row in + # logits + num_rows = prompt_len + + batch_size += num_rows + sampling_params_per_row.extend( + itertools.repeat(sampling_params, num_rows)) + + assert len( + expected_penalization + ) == batch_size, \ + ("Invalid test case, expected_penalization does not match computed" + "batch size") + + _, fake_logits, sampler = _prepare_test(batch_size) + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens=prompt_lens, - subquery_lens=prompt_lens) + seq_lens=seq_lens if seq_lens else None, + query_lens=seq_lens if seq_lens else None, + device=device, + pin_memory=is_pin_memory_available()) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) for logits_idx, (should_penalize, sampling_params) in enumerate( - zip(expected_penalization, sampling_params_per_seq)): + zip(expected_penalization, sampling_params_per_row)): - tokens_to_check = [sampling_params.eos_token_id] - if sampling_params.stop_token_ids: - tokens_to_check.extend(sampling_params.stop_token_ids) - tokens_to_check = set(tokens_to_check) + tokens_to_check = sampling_params.all_stop_token_ids if should_penalize: for token_id in tokens_to_check: @@ -398,8 +460,6 @@ def run_test_case(*, fake_logits[logits_idx, :] == -float('inf')) == 0, "No tokens should have been penalized" - del model_runner - for test_case in test_cases: run_test_case(**test_case) @@ -410,12 +470,11 @@ def test_sampler_mixed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) seq_group_metadata_list = [] expected_tokens: List[Optional[List[int]]] = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): expected: Optional[List[int]] = None sampling_type = random.randint(0, 3) @@ -450,11 +509,15 @@ def test_sampler_mixed(seed: int, device: str): sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - def test_sampling(model_runner: ModelRunner): - sampling_metadata = model_runner._prepare_sample( - seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) + def test_sampling(): + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=device, + pin_memory=is_pin_memory_available()) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -484,12 +547,12 @@ def test_sampling(model_runner: ModelRunner): assert nth_output.output_token in expected_tokens[i] # Test batch - test_sampling(model_runner) + test_sampling() # Shuffle the batch and resample target_index = list(range(batch_size)) for list_to_shuffle in (target_index, seq_group_metadata_list, - expected_tokens, prompt_lens): + expected_tokens, seq_lens): random.Random(seed).shuffle(list_to_shuffle) target_index = torch.tensor(target_index) input_tensor.data = input_tensor.index_select(0, target_index) @@ -497,9 +560,7 @@ def test_sampling(model_runner: ModelRunner): # This time, results of seeded random samples will be compared with # the corresponding sample in the pre-shuffled batch - test_sampling(model_runner) - - del model_runner + test_sampling() @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -519,7 +580,6 @@ def test_sampler_top_k_top_p(seed: int, device: str): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(None, None, None, None, None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, @@ -529,7 +589,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): assert len(warpers) == 2 # top_p and top_k seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -543,18 +603,22 @@ def test_sampler_top_k_top_p(seed: int, device: str): ), block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=device, + pin_memory=is_pin_memory_available()) sample_probs = None def mock_sample(probs, *args, **kwargs): nonlocal sample_probs sample_probs = probs - return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] + return ([[prob.topk(1, dim=-1).indices.tolist(), [0]] + for prob in probs], None) with patch("vllm.model_executor.layers.sampler._sample", mock_sample): sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -562,5 +626,3 @@ def mock_sample(probs, *args, **kwargs): hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) assert torch.allclose(hf_probs, sample_probs, atol=1e-5) assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) - - del model_runner diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index 3cd659cef58da..fef5ff3fb9e8e 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -57,11 +57,7 @@ def test_random_sample_with_seed( sampling_params_seed_1, sampling_params_seed_2, ): - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - sampling_params=params, - ) + llm._add_request(prompt, params=params) results = llm._run_engine(use_tqdm=False) all_outputs = [[out.token_ids for out in output.outputs] diff --git a/tests/spec_decode/e2e/__init__.py b/tests/spec_decode/e2e/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 1d99cb5d32219..7c5840baf3593 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,41 +1,322 @@ +import asyncio +import time +from itertools import cycle +from typing import Dict, List, Optional, Tuple, Union + import pytest +import ray +import torch + +from vllm.utils import is_hip + +if (not is_hip()): + from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, + nvmlInit) -from tests.conftest import cleanup from vllm import LLM +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.lora.request import LoRARequest from vllm.model_executor.utils import set_random_seed +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.sequence import Logprob, MultiModalData +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter, random_uuid + +from ...conftest import cleanup + + +class AsyncLLM: + """AsyncLLM + + Note: Current LLM class in vllm don't support async mode, for test purpose, + we implement async one in here. Maybe we could move to + vllm/entrypoints/llm.py in future. + + Below AsyncLLM is directly borrow from vllm/entrypoints/llm.py with changes + to make to work in async mode. + """ + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + enforce_eager: bool = False, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + enforce_eager=enforce_eager, + max_seq_len_to_capture=max_seq_len_to_capture, + engine_use_ray=True, + disable_custom_all_reduce=disable_custom_all_reduce, + **kwargs, + ) + self.request_counter = Counter() + self.llm_engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS) + + def generate( + self, + prompts: Optional[Union[str, List[str]]] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + + if prompts is None: + raise ValueError("prompts must be provided.") + if isinstance(prompts, str): + # Convert a single prompt to a list. + prompts = [prompts] + + if prompts is not None: + num_requests = len(prompts) + + if sampling_params is None: + # Use default sampling params. + sampling_params = SamplingParams() + + elif isinstance(sampling_params, + list) and len(sampling_params) != num_requests: + raise ValueError("The lengths of prompts and " + "sampling_params must be the same.") + + async def get_output(prompt, sampling_param) -> str: + request_id = random_uuid() + results_generator = self.llm_engine.generate( + prompt, sampling_param, request_id) + final_output = None + async for request_output in results_generator: + final_output = request_output + return final_output + + outputs = [] + try: + for i in range(num_requests): + prompt = prompts[i] if prompts is not None else None + res = asyncio.run(get_output(prompt, sampling_params)) + outputs.append(res) + finally: + ray.shutdown() + return outputs @pytest.fixture -def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, +def baseline_llm_generator(request, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + seed): + return create_llm_generator("baseline", request, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, seed) @pytest.fixture -def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, +def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs, test_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - test_llm_kwargs, seed) + return create_llm_generator("test", request, common_llm_kwargs, + per_test_common_llm_kwargs, test_llm_kwargs, + seed) -def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - distinct_llm_kwargs, seed): +def create_llm_generator(baseline_or_test, request, common_llm_kwargs, + per_test_common_llm_kwargs, distinct_llm_kwargs, + seed): kwargs = { **common_llm_kwargs, **per_test_common_llm_kwargs, **distinct_llm_kwargs, } + test_name = request.node.name def generator_inner(): - llm = LLM(**kwargs) + wait_for_gpu_memory_to_clear( + devices=list(range(torch.cuda.device_count())), + threshold_bytes=2 * 2**30, + timeout_s=60, + ) + + use_async = False + if "use_async" in kwargs: + use_async = kwargs.pop("use_async") + print(f'{use_async=}') + + print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}') + llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs) set_random_seed(seed) yield llm del llm cleanup() - for llm in generator_inner(): - yield llm + def generator_outer(): + for llm in generator_inner(): + yield llm + del llm + + return generator_outer + + +def maybe_assert_ngram_worker(llm): + # Verify the proposer worker is ngram if ngram is specified. + if (not isinstance(llm, AsyncLLM) + and llm.llm_engine.speculative_config is not None + and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0): + from vllm.spec_decode.ngram_worker import NGramWorker + assert isinstance( + llm.llm_engine.model_executor.driver_worker.proposer_worker, + NGramWorker) + + +def get_output_from_llm_generator( + llm_generator, prompts, + sampling_params) -> Tuple[List[str], List[List[int]]]: + tokens = [] + token_ids = [] + for llm in llm_generator(): + maybe_assert_ngram_worker(llm) + + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] + tokens = [output.outputs[0].text for output in outputs] + del llm + + return tokens, token_ids + + +def get_logprobs_from_llm_generator( + llm_generator, prompts, + sampling_params) -> List[List[Dict[int, Logprob]]]: + """Returns a dict of (token_id: Logprob) for each generated position, for + each sequence in the batch. + """ + for llm in llm_generator(): + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + logprobs = [output.outputs[0].logprobs[:] for output in outputs] del llm + + return logprobs + + +def run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + print_tokens: bool = False): + """Helper method that compares the outputs of both the baseline LLM and + the test LLM. It asserts greedy equality, e.g. that the outputs are exactly + the same when temperature is zero. + """ + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + ) + + spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + (baseline_batch_tokens, + baseline_batch_token_ids) = get_output_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + assert len(baseline_batch_token_ids) == len(prompts) + assert len(spec_batch_token_ids) == len(prompts) + + for i, (baseline_token_ids, baseline_tokens, spec_token_ids, + spec_tokens) in enumerate( + zip(baseline_batch_token_ids, baseline_batch_tokens, + spec_batch_token_ids, spec_batch_tokens)): + if print_tokens: + print(f'{i=} {baseline_tokens=}') + print(f'{i=} {spec_tokens=}') + print(f'{i=} {baseline_token_ids=}') + print(f'{i=} {spec_token_ids=}') + assert baseline_token_ids == spec_token_ids + + +def wait_for_gpu_memory_to_clear(devices: List[int], + threshold_bytes: int, + timeout_s: float = 120) -> None: + # Use nvml instead of pytorch to reduce measurement error from torch cuda + # context. + nvmlInit() + start_time = time.time() + while True: + output = {} + output_raw = {} + for device in devices: + dev_handle = nvmlDeviceGetHandleByIndex(device) + mem_info = nvmlDeviceGetMemoryInfo(dev_handle) + gb_used = mem_info.used / 2**30 + output_raw[device] = gb_used + output[device] = f'{gb_used:.02f}' + + print('gpu memory used (GB): ', end='') + for k, v in output.items(): + print(f'{k}={v}; ', end='') + print('') + + dur_s = time.time() - start_time + if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()): + print(f'Done waiting for free GPU memory on devices {devices=} ' + f'({threshold_bytes/2**30=}) {dur_s=:.02f}') + break + + if dur_s >= timeout_s: + raise ValueError(f'Memory of devices {devices=} not free after ' + f'{dur_s=:.02f} ({threshold_bytes/2**30=})') + + time.sleep(5) diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py new file mode 100644 index 0000000000000..81f91c5e10b0d --- /dev/null +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -0,0 +1,126 @@ +import pytest + +from vllm import SamplingParams + +from .conftest import get_output_from_llm_generator + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "enable_chunked_prefill": True, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_chunked_prefill(test_llm_generator): + """Verify that speculative decoding with chunked prefill fails. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(ValueError, + match="Speculative decoding and chunked prefill"): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "meta-llama/Llama-2-7b-chat-hf", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Speculative max model len > overridden max model len should raise. + "max_model_len": 128, + "speculative_max_model_len": 129, + }, + { + # Speculative max model len > draft max model len should raise. + # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12 + "speculative_max_model_len": 2048 + 1, + }, + { + # Speculative max model len > target max model len should raise. + # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12 + "speculative_max_model_len": 4096 + 1, + }, + ]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_spec_max_model_len(test_llm_generator): + """Verify that speculative decoding validates speculative_max_model_len. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(ValueError, match="cannot be larger than"): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) + + +@pytest.mark.parametrize("common_llm_kwargs", [{ + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, +}]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_block_manager_v1(test_llm_generator): + """Verify that speculative decoding with block manager v1 fails. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(ValueError, + match="Speculative decoding requires usage of the V2"): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py deleted file mode 100644 index b5a6fcb7900a3..0000000000000 --- a/tests/spec_decode/e2e/test_correctness.py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest - -from vllm import SamplingParams - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - "speculative_model": "facebook/opt-125m", - "num_speculative_tokens": 5, - - # Required for spec decode. - "use_v2_block_manager": True - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_config(test_llm_generator): - output_len = 1024 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - with pytest.raises( - AssertionError, - match="Speculative decoding not yet supported for GPU backend"): - get_token_ids_from_llm_generator(test_llm_generator, prompts, - sampling_params) - - -def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): - for llm in llm_generator: - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - token_ids = [output.outputs[0].token_ids for output in outputs] - del llm - - return token_ids diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py new file mode 100644 index 0000000000000..4a2b62151f8cd --- /dev/null +++ b/tests/spec_decode/e2e/test_integration.py @@ -0,0 +1,44 @@ +"""Tests which cover integration of the speculative decoding framework with +other features, e.g. cuda graphs. +""" + +import pytest + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Required for spec decode. + "use_v2_block_manager": True, + + # Verify equality when cuda graphs allowed. + "enforce_eager": False, + "model": "JackFram/llama-68m", + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Identical models. + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator, + batch_size, output_len): + """Verify spec decode equality when cuda graphs are enabled. + """ + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) diff --git a/tests/spec_decode/e2e/test_integration_dist.py b/tests/spec_decode/e2e/test_integration_dist.py new file mode 100644 index 0000000000000..d444ef24cbfda --- /dev/null +++ b/tests/spec_decode/e2e/test_integration_dist.py @@ -0,0 +1,65 @@ +"""Tests which cover integration of the speculative decoding framework with +tensor parallelism. +""" + +import pytest +import torch + +from vllm.utils import is_hip + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "tensor_parallel_size": 2, + + # Use AsyncLLM engine, so that the engine runs in its own process. + # Otherwise, since vLLM does not follow true SPMD, the test runner + # process will have both the engine and the rank0 worker. NCCL is not + # cleaned up properly, and its server host thread leaks, causing the + # second run of the test to fail with internal NCCL error. + "use_async": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when tensor parallelism is used. + """ + if is_hip(): + pytest.skip("hip is not well-supported yet") + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py new file mode 100644 index 0000000000000..9572aac7df6e0 --- /dev/null +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -0,0 +1,335 @@ +import math +from itertools import cycle + +import pytest + +from vllm import SamplingParams + +from .conftest import get_logprobs_from_llm_generator + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "max_logprobs": 6, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 7, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_equality(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify output logprobs are equal with and without speculative decoding. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "max_logprobs": 6, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("num_logprobs", [6]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 7, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int, + num_logprobs: int): + """Verify output logprobs are equal with and without spec decode. + This specifies a number of logprobs >1. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + logprob_rank=num_logprobs) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}, { + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 6, +}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Veriy logprob greedy equality with different speculation lens. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + + # Artificially limit the draft model max model len; this forces vLLM + # to skip speculation once the sequences grow beyond 32-k tokens. + "speculative_max_model_len": 32, + }]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_when_skip_speculation(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify logprobs greedy equality when some sequences skip speculation. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify at least one logprob result has num_logprobs+1, which tests the + case where the sampled token is not in top-k logprobs. + + Ideally, this test should validate equality with non-spec by getting + logprobs. This is left as future improvement. + """ + batch_size = 8 + max_output_len = output_len + force_output_len = True + logprob_rank = 5 + + temperature = 1.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + logprobs=logprob_rank, + ) + + spec_batch_logprobs = get_logprobs_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + num_returned_logprobs = [ + len(logprob_dict) for seq_logprobs in spec_batch_logprobs + for logprob_dict in seq_logprobs + ] + + # Assert one of the returned logprobs has > num_logprobs (indicating the + # sampled token is not in top-k). + assert any([ + num_returned > logprob_rank for num_returned in num_returned_logprobs + ]) + + +def run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + logprob_rank: int = 1): + """Helper method that compares the logprobs outputs of both the baseline LLM + and the test LLM. It asserts greedy equality of the logprobs when the + temperature is zero. + """ + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + logprobs=logprob_rank, + ) + + spec_batch_logprobs = get_logprobs_from_llm_generator( + test_llm_generator, prompts, sampling_params) + baseline_batch_logprobs = get_logprobs_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + assert len(baseline_batch_logprobs) == len(prompts) + assert len(spec_batch_logprobs) == len(prompts) + + # For each sequence in the batch. + for i, (baseline_logprobs, spec_logprobs) in enumerate( + zip(baseline_batch_logprobs, spec_batch_logprobs)): + assert len(spec_logprobs) == len(baseline_logprobs) + + # For each generated position of the sequence. + for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate( + zip(spec_logprobs, baseline_logprobs)): + + # Map rank to token/logprob in spec output. + spec_rank_to_token_id = { + value.rank: key + for key, value in spec_pos_logprobs.items() + } + spec_rank_to_logprob = { + value.rank: value.logprob + for key, value in spec_pos_logprobs.items() + } + + # Map rank to token/logprob in baseline output. + baseline_rank_to_token_id = { + value.rank: key + for key, value in baseline_pos_logprobs.items() + } + baseline_rank_to_logprob = { + value.rank: value.logprob + for key, value in baseline_pos_logprobs.items() + } + + # Assert set of ranks returned is equal. + assert set(spec_rank_to_token_id.keys()) == set( + baseline_rank_to_token_id.keys()) + + # Assert each logprob/token id is correct, keyed by rank. + for rank in sorted(set(spec_rank_to_token_id.keys())): + assert spec_rank_to_token_id[ + rank] == baseline_rank_to_token_id[rank], f"{rank}" + assert math.isclose( + a=spec_rank_to_logprob[rank], + b=baseline_rank_to_logprob[rank], + abs_tol=1e-1, + ) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py new file mode 100644 index 0000000000000..94d71fb012727 --- /dev/null +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -0,0 +1,613 @@ +"""The tests in this file verify end-to-end speculative decoding correctness. + +This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. This gives us good coverage of temp=0. + +For temp>0, we rely on unit tests on the rejection sampler to verify that the +output distribution is the same with spec decode vs. no spec decode (this would +be prohibitively expensive to run with a real model). + +NOTE: Speculative decoding's distribution equality requires that the measured +distributions of the target model and proposal model be deterministic given the +same input. vLLM largely guarantees this. + +@cadedaniel has seen cases where the output probabilities of a draft/target +model change slightly with certain batch sizes or prompts, even with Torch +determinism flags set. It is unclear if this is a bug in vLLM, due to non- +determinism in on-device batched operations, a bug in vLLM's spec decode +implementation, or the "hardware numerics" limitations. Either way, rejection +sampling ensures the output distribution matches the target model, but it breaks +greedy-equality tests for those batch sizes/prompts. +""" + +from itertools import cycle + +import pytest +from transformers import AutoTokenizer + +from vllm import SamplingParams + +from .conftest import (get_output_from_llm_generator, + run_greedy_equality_correctness_test) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a tokenizer. + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + { + # Verify the detokenizer assertions in the test work when spec + # decode is disabled. + }, + ]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_with_detokenization(test_llm_generator, + batch_size: int): + """Run generation with speculative decoding on a batch. Verify the engine + generates the correct number of tokens (via ignore_eos=True), and that the + detokenization matches HF transformers. + """ + output_len = 32 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + batch_tokens, batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + # Expect a generation for each prompt in the batch. + assert len(batch_token_ids) == len(prompts) + + # Expect each generation to have expected number of tokens (note ignore_eos + # is True). + assert [len(token_ids) + for token_ids in batch_token_ids] == ([output_len] * batch_size) + + # Expect detokenized string to match. + tok = AutoTokenizer.from_pretrained("JackFram/llama-68m") + for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids): + expected_tokens = tok.decode(actual_token_ids) + print(f"{actual_token_ids=}") + assert actual_tokens.strip() == expected_tokens.strip() + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a tokenizer. + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Use AsyncLLM engine + "use_async": True, + }]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_with_async_engine(test_llm_generator, + baseline_llm_generator, + batch_size: int): + """Verify spec decode works well with async LLM engine. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=32, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + # Try two different tiny base models. + # Note that one is equal to the draft model, another isn't. + { + "model": "JackFram/llama-68m", + }, + { + "model": "JackFram/llama-160m", + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use long output len for the small model test. + 1536, + ]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a tiny model with batch size of one. + + Since this test is cheaper than other e2e correctness tests, we generate + with a higher output_len. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + # Try two different tiny base models. + # Note that one is equal to the draft model, another isn't. + { + "model": "JackFram/llama-68m", + }, + { + "model": "JackFram/llama-160m", + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 256, + ]) +@pytest.mark.parametrize("batch_size", [64]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a tiny model and large batch size. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + # Try two different tiny base models. + # Note that one is equal to the draft model, another isn't. + { + "model": "JackFram/llama-68m", + }, + { + "model": "JackFram/llama-160m", + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("max_output_len", [ + 256, +]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( + baseline_llm_generator, test_llm_generator, batch_size: int, + max_output_len: int): + """Verify greedy equality on a tiny model, with a large batch size, and when + sampling respects the EOS token. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len=False) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # A "real" model (not tiny). + "model": "meta-llama/Llama-2-7b-chat-hf", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize( + "output_len", + [ + # Use decently long output len for a high quality test. + 256, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_real_model_bs1( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a "real" model and batch size of 1. This is + separate from large BS tests to make identifying the source of bugs easier. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # A "real" model (not tiny). + "model": "meta-llama/Llama-2-7b-chat-hf", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 64, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality with a "real" model on a nontrivial batch size. + This is the closest test to a real production workload. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-160m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 256, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + # As of this writing, vLLM only compiles with these 3 block sizes by + # default. + { + "block_size": 8, + }, + { + "block_size": 16, + }, + { + "block_size": 32, + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_different_block_size(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality over different block sizes. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Artificially limit the draft model max model len; this forces vLLM + # to skip speculation once the sequences grow beyond 32-k tokens. + "speculative_max_model_len": 32, + }, + ]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # This must be a good bit larger than speculative_max_model_len so that + # we can test the case where all seqs are skipped, but still small to + # ensure fast test. + 64, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_skip_speculation(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when some (or all) sequences skip speculation. + We do this by setting the max model len of the draft model to an + artificially low value, such that when the sequences grow beyond it, they + are skipped in speculative decoding. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_disable_by_batch_size": 2, + }, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_disable_speculation(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when all sequences disable speculation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": k, + } + # Try a range of common k, as well as large speculation. + for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify that speculative decoding produces exact equality to without spec + decode with many different values of k. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py new file mode 100644 index 0000000000000..d475d37af6425 --- /dev/null +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -0,0 +1,213 @@ +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding, +and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775. +Since there is no model is needed for generate the proposal, we could make +the testcase much simpler than drafter multi-step one. + +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various ngram sizes / speculative sizes + +With those tests, we can say at least, ngram spec would not break the correctess +for the target model outputs. +""" + +import pytest + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-68m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize("output_len", [ + 256, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_e2e_greedy_correctness(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a tiny model with different batch size.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-160m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 256, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 3, + } + # Try a range of common k, as well as large speculation. + for k in [1, 3, 5] + ] + [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 1, + } + # Try a range of common k, as well as large speculation. + for k in [1, 3, 5] + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that ngram speculative decoding produces exact equality + to without spec decode with many different values of k and + different ngram_prompt_lookup_max. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that ngram speculative decoding produces exact equality + to without spec decode with many different values of k and + different ngram_prompt_lookup_max. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py new file mode 100644 index 0000000000000..48fa862b2e41a --- /dev/null +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -0,0 +1,81 @@ +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.metrics import AsyncMetricsCollector +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker +from vllm.spec_decode.top1_proposer import Top1Proposer + +from .utils import create_batch, mock_worker + + +@pytest.mark.parametrize('queue_size', [4]) +@pytest.mark.parametrize('batch_size', [1]) +@pytest.mark.parametrize('k', [1]) +@torch.inference_mode() +def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): + """Verify that speculative tokens are disabled when the batch size + exceeds the threshold. + """ + disable_by_batch_size = 3 + + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + worker = SpecDecodeWorker(proposer_worker=draft_worker, + scorer_worker=target_worker, + rejection_sampler=rejection_sampler, + metrics_collector=metrics_collector, + disable_by_batch_size=disable_by_batch_size) + + exception_secret = 'artificial stop' + draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + running_queue_size=queue_size) + + if queue_size > disable_by_batch_size: + with patch.object(worker, + '_run_no_spec', + side_effect=ValueError(exception_secret)), \ + pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + + # When the batch size is larger than the threshold, + # we expect no speculative tokens (0). + expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0 + assert seq_group_metadata_list[ + 0].num_speculative_tokens == expected_num_spec_tokens + + draft_worker.sampler_output.side_effect = ValueError(exception_secret) + + proposer = Top1Proposer( + worker=draft_worker, + device='cpu', # not used + vocab_size=100, # not used + # Must be long enough to avoid being skipped due to length. + max_proposal_len=1024, + ) + + if queue_size < disable_by_batch_size: + # Should raise exception when executing the mocked draft model. + with pytest.raises(ValueError, match=exception_secret): + proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) + else: + # Should not execute the draft model because spec decode is disabled + # for all requests. Accordingly, the proposal length should be 0. + proposals = proposer.get_proposals( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) + assert proposals.proposal_lens.tolist() == [0] * batch_size diff --git a/tests/spec_decode/test_metrics.py b/tests/spec_decode/test_metrics.py index 36e91672069dc..312878804b86e 100644 --- a/tests/spec_decode/test_metrics.py +++ b/tests/spec_decode/test_metrics.py @@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): num_draft_tokens = 0 k = 5 - num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens( + max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens( num_draft_tokens, k) rej_sampler = MagicMock() @@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): assert (metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens) assert (metrics.system_efficiency == num_emitted_tokens / - num_possible_tokens) + max_num_emitted_tokens) else: assert math.isnan(metrics.draft_acceptance_rate) assert math.isnan(metrics.system_efficiency) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index f4d44108b47c2..cb2de97a4af94 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -5,13 +5,12 @@ import torch from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplerOutput -from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer, - MultiStepWorker) +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker import Worker from .utils import (assert_logprobs_dict_allclose, create_batch, - create_execute_model_data, create_seq_group_metadata_from_prompts, create_worker, patch_execute_model_with_seeds, zero_kv_cache) @@ -34,7 +33,7 @@ def test_assert_enough_kv_space(num_steps: int): list(range(block_size * 2)), ] - final_seq_lens = [ + final_prompt_lens = [ len(prompt + output) + num_steps for prompt, output in zip(prompts, prev_output_tokens) ] @@ -43,7 +42,7 @@ def test_assert_enough_kv_space(num_steps: int): prompts, num_gpu_blocks, block_size, - final_seq_lens, + final_prompt_lens, continuations=prev_output_tokens) assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access @@ -103,29 +102,34 @@ def test_same_output_for_single_step(): [6, 7, 8, 9, 10], ] - final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] - multi_step_execute_model_data = create_execute_model_data( - seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) - - single_step_execute_model_data = create_execute_model_data( - seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + multi_step_seq_group = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) - actual_output = multi_step_worker.execute_model_multi_step( - **multi_step_execute_model_data.to_dict(), num_steps=num_steps) + actual_output, _ = multi_step_worker.sampler_output( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=multi_step_seq_group), + sample_len=num_steps) assert len(actual_output) == num_steps actual_output = actual_output[0] + single_step_seq_group = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + zero_kv_cache(worker.cache_engine) set_random_seed(seed) expected_output = worker.execute_model( - **single_step_execute_model_data.to_dict(), ) + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=single_step_seq_group))[0] actual_token_ids = [ output.samples[0].output_token for output in actual_output @@ -181,7 +185,7 @@ def test_same_output_for_multi_step(): random.randint(0, 1000) for _ in range(random.randint(10, 20)) ] for _ in range(10)] - final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) multi_step_worker.execute_model = patch_execute_model_with_seeds( @@ -189,19 +193,20 @@ def test_same_output_for_multi_step(): worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) continuations = [[1] for _ in prompts] - execute_model_data = create_execute_model_data( - create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_seq_lens=final_seq_lens), ) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) # Run multi-step. zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) - multi_step_output = multi_step_worker.execute_model_multi_step( - **execute_model_data.to_dict(), num_steps=num_steps) + multi_step_output, _ = multi_step_worker.sampler_output( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list), + sample_len=num_steps) # Run single-step repeatedly. zero_kv_cache(worker.cache_engine) @@ -211,16 +216,16 @@ def test_same_output_for_multi_step(): for _ in multi_step_output: - execute_model_data = create_execute_model_data( - create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_seq_lens=final_seq_lens)) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) - single_step_output.append( - worker.execute_model(**execute_model_data.to_dict(), )) + single_step_output.extend( + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list))) # Append output tokens to new sequence data. for i, seq_group_output in enumerate(single_step_output[-1]): @@ -266,7 +271,7 @@ def test_same_output_for_multi_step(): @torch.inference_mode() def test_draft_proposals_full_speculation_len(): - """Verify DraftModelTop1Proposer correctly handles case where all sequences + """Verify Top1Proposer correctly handles case where all sequences can speculate. """ k = 10 @@ -275,33 +280,36 @@ def test_draft_proposals_full_speculation_len(): device = 'cuda:0' draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=2048, vocab_size=vocab_size, + max_proposal_len=2048, ) - draft_worker.execute_model_multi_step.return_value = [ + draft_worker.sampler_output.return_value = [ SamplerOutput( outputs=[], sampled_token_probs=torch.rand(batch_size, vocab_size, device=device, dtype=torch.float32), + logprobs=torch.rand(batch_size, + vocab_size, + device=device, + dtype=torch.float32), sampled_token_ids=torch.randint(low=0, high=vocab_size, size=(batch_size, ), device=device, dtype=torch.long), ) for _ in range(k) - ] + ], True - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) - proposals = proposer.get_proposals( - **execute_model_data.to_dict(), - max_proposal_len=k, - ) + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -315,7 +323,7 @@ def test_draft_proposals_full_speculation_len(): @torch.inference_mode() def test_draft_proposals_no_speculations(): - """Verify DraftModelTop1Proposer correctly handles case where no sequences + """Verify Top1Proposer correctly handles case where no sequences can speculate. """ k = 10 @@ -325,27 +333,26 @@ def test_draft_proposals_no_speculations(): prompt_len = 10 draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=prompt_len + k - 1, vocab_size=vocab_size, + max_proposal_len=prompt_len + k - 1, ) - execute_model_data, _, _ = create_batch(batch_size, - k, - prompt_len=prompt_len) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + prompt_len=prompt_len) - proposals = proposer.get_proposals( - **execute_model_data.to_dict(), - max_proposal_len=k, - ) + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) - assert proposals.proposal_token_ids.shape == torch.Size([0, k]) - assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k]) + assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k]) assert proposals.proposal_lens.shape == torch.Size([batch_size]) assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)] @@ -353,7 +360,7 @@ def test_draft_proposals_no_speculations(): @torch.inference_mode() def test_draft_proposals_mixed_k(): - """Verify DraftModelTop1Proposer correctly handles case some sequences can + """Verify Top1Proposer correctly handles case some sequences can speculate and some can't. """ k = 10 @@ -374,20 +381,24 @@ def test_draft_proposals_mixed_k(): for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len] draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=long_prompt_len + prev_output_token_len + k - 1, vocab_size=vocab_size, + max_proposal_len=long_prompt_len + prev_output_token_len + k - 1, ) - draft_worker.execute_model_multi_step.return_value = [ + draft_worker.sampler_output.return_value = [ SamplerOutput( outputs=[], sampled_token_probs=torch.rand(expected_num_proposal_seqs, vocab_size, device=device, dtype=torch.float32), + logprobs=torch.rand(expected_num_proposal_seqs, + vocab_size, + device=device, + dtype=torch.float32), sampled_token_ids=torch.randint( low=0, high=vocab_size, @@ -395,19 +406,18 @@ def test_draft_proposals_mixed_k(): device=device, dtype=torch.long), ) for _ in range(k) - ] + ], True - execute_model_data, _, _ = create_batch( + seq_group_metadata_list, _, _ = create_batch( batch_size, k, prompt_len=prompt_len, prev_output_token_len=prev_output_token_len, ) - proposals = proposer.get_proposals( - **execute_model_data.to_dict(), - max_proposal_len=k, - ) + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py new file mode 100644 index 0000000000000..88b40d1eb4674 --- /dev/null +++ b/tests/spec_decode/test_ngram_worker.py @@ -0,0 +1,207 @@ +import torch + +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.ngram_worker import NGramWorker +from vllm.spec_decode.top1_proposer import Top1Proposer + +from .utils import create_seq_group_metadata_from_prompts, create_worker + + +def test_ngram_algo_correctness_for_single_no_match(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario cannot find any candidate in one single batch + """ + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window [1, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(1, 3) + + prompts = [ + # shall find no candidate + [1, 2, 3, 4, 5, 6, 7], + ] + + proposal_len = 5 + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=proposal_len), ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([1]) + assert proposals.proposal_lens.tolist() == [0] + + +def test_ngram_algo_correctness_for_batches_not_match_all(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario find some candidate not full in batchs + """ + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window [1, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(1, 3) + + prompts = [ + # shall find no candidate + [1, 2, 3, 4, 5, 6, 7], + # shall find candidate 12,13,14,15,16 + [11, 12, 13, 14, 15, 16, 11], + # shall find candidate 23,24,25,26,21 + [21, 21, 22, 23, 24, 25, 26, 21, 22], + # shall find candidate 34,35,36,37,38 + [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33], + # shall find no candidate as exceed max_proposal_len + [ + 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37, + 38, 31, 32, 33 + ], + ] + + proposal_len = 5 + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=proposal_len), ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([5]) + + # the first sequence has no match so proposal_len should be overwritten to 0 + assert proposals.proposal_lens.tolist( + ) == [0] + [proposal_len for _ in range(3)] + [0] + + for i in range(proposal_len): + assert proposals.proposal_token_ids[0][i] == -1 + assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1] + assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3] + assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5] + assert proposals.proposal_token_ids[4][i] == -1 + + +def test_ngram_algo_correctness_for_batches_match_all(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario find candidate in all batchs + """ + + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window [0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(1, 3) + + prompts = [ + # shall find candidate 12,13,14,15,16 + [11, 12, 13, 14, 15, 16, 11], + # shall find candidate 23,24,25,26,21 + [21, 21, 22, 23, 24, 25, 26, 21, 22], + # shall find candidate 34,35,36,37,38 + [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33], + ] + + proposal_len = 5 + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=proposal_len), ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([3]) + + assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)] + + for i in range(proposal_len): + assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1] + assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3] + assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5] diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 47aff8f575413..ef9d32f73d668 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -1,4 +1,5 @@ import random +from types import SimpleNamespace from unittest.mock import MagicMock import pytest @@ -6,6 +7,7 @@ from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.utils import set_random_seed +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) @@ -13,8 +15,7 @@ from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, split_num_cache_blocks_evenly) -from .utils import (ExecuteModelData, create_batch, create_sampler_output_list, - mock_worker) +from .utils import create_batch, create_sampler_output_list, mock_worker @pytest.mark.parametrize('k', [1, 2, 6]) @@ -31,26 +32,22 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(execute_model_req=execute_model_req) call_args_list = draft_worker.get_spec_proposals.call_args_list assert len(call_args_list) == 1 for args, _ in call_args_list: - (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, - blocks_to_copy, actual_k) = args - actual_execute_model_data = ExecuteModelData(seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy) - assert actual_execute_model_data == execute_model_data - assert actual_k == k + actual_execute_model_data = args[0] + assert actual_execute_model_data == execute_model_req @pytest.mark.parametrize('k', [1, 2, 6]) @@ -60,8 +57,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int): """Verify SpecDecodeWorker calls the target model with correct inputs. Everything else is mocked out. """ - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() + draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) + target_worker = mock_worker(use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -90,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, prompts, prev_output_tokens = create_batch( + seq_group_metadata_list, prompts, prev_output_tokens = create_batch( batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( @@ -98,23 +95,24 @@ def test_correctly_calls_target_model(k: int, batch_size: int): proposal_probs=proposal_probs, proposal_lens=proposal_lens) - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' target_worker.execute_model.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) seen_contexts = [] call_args_list = target_worker.execute_model.call_args_list assert len(call_args_list) == 1 - for args, kwargs in call_args_list: - target_execute_model_data = ExecuteModelData.from_dict(kwargs) + for _, kwargs in call_args_list: + seq_group_metadata_list = kwargs[ + "execute_model_req"].seq_group_metadata_list - assert len(target_execute_model_data.seq_group_metadata_list) == ( - k + 1) * batch_size - for seq_group_metadata in ( - target_execute_model_data.seq_group_metadata_list): + assert len(seq_group_metadata_list) == (k + 1) * batch_size + for seq_group_metadata in seq_group_metadata_list: for seq_data in seq_group_metadata.seq_data.values(): seen_contexts.append(seq_data.get_token_ids()) @@ -141,8 +139,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): """ vocab_size = 32_000 - draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) - target_worker = mock_worker(vocab_size=vocab_size) + draft_worker = mock_worker(cls=MultiStepWorker, + vocab_size=vocab_size, + use_spec=False) + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -169,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( proposal_token_ids=proposal_token_ids, @@ -186,29 +186,36 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' rejection_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) assert len(rejection_sampler.call_args_list) == 1 - args, _ = rejection_sampler.call_args_list[0] - (actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs, - actual_proposal_token_ids) = args + _, kwargs = rejection_sampler.call_args_list[0] + actual = SimpleNamespace(**kwargs) - assert torch.equal(actual_bonus_token_ids, + assert torch.equal(actual.bonus_token_ids, target_token_ids.reshape(batch_size, k + 1)[:, -1:]) assert torch.equal( - actual_proposal_scores, + actual.target_probs, target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) - assert torch.equal(actual_proposal_token_ids, proposal_token_ids) - assert torch.equal(actual_proposal_probs, proposal_probs) + assert torch.equal(actual.draft_token_ids, proposal_token_ids) + assert torch.equal(actual.draft_probs, proposal_probs) @pytest.mark.parametrize('k', [1, 2, 6]) @@ -220,8 +227,10 @@ def test_correctly_formats_output(k: int, batch_size: int): """ vocab_size = 32_000 - draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) - target_worker = mock_worker(vocab_size=vocab_size) + draft_worker = mock_worker(cls=MultiStepWorker, + vocab_size=vocab_size, + use_spec=False) + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -248,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( proposal_token_ids=proposal_token_ids, @@ -265,10 +274,16 @@ def test_correctly_formats_output(k: int, batch_size: int): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] rejection_sampler_output = torch.randint(low=0, high=vocab_size, @@ -282,15 +297,18 @@ def test_correctly_formats_output(k: int, batch_size: int): rejection_sampler.return_value = rejection_sampler_output - output = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + output = worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) expected_output = create_sampler_output_list( - rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) + token_ids=rejection_sampler_output.transpose(0, 1), + probs=[None for _ in range(k + 1)], + logprobs=[None for _ in range(k + 1)]) seq_ids = [ next(iter(seq_group_metadata.seq_data.keys())) - for seq_group_metadata in execute_model_data.seq_group_metadata_list + for seq_group_metadata in seq_group_metadata_list ] actual_output_by_seq = {seq_id: [] for seq_id in seq_ids} expected_output_by_seq = {seq_id: [] for seq_id in seq_ids} @@ -320,7 +338,6 @@ def test_correctly_formats_output(k: int, batch_size: int): continue assert actual_by_step[i].output_token == expected_by_step[ i].output_token - assert actual_by_step[i].logprobs == expected_by_step[i].logprobs @pytest.mark.parametrize('k', [1, 2]) @@ -332,8 +349,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): """ vocab_size = 32_000 - draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) - target_worker = mock_worker(vocab_size=vocab_size) + draft_worker = mock_worker(cls=MultiStepWorker, + vocab_size=vocab_size, + use_spec=False) + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -360,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( proposal_token_ids=proposal_token_ids, @@ -377,10 +396,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] rejection_sampler_output = torch.randint(low=0, high=vocab_size, @@ -399,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): metrics_collector.maybe_collect_rejsample_metrics.return_value = ( mock_rejsample_metrics) - output = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + output = worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics call_args_list = ( @@ -423,6 +449,8 @@ def test_k_equals_zero(k: int, batch_size: int): rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) + target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -431,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - execute_model_data, prompts, prev_output_tokens = create_batch( - batch_size, k, prev_output_token_len=0) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + prev_output_token_len=0) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) - out = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + out = worker.execute_model(execute_model_req=execute_model_req) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" assert out[ 0].sampled_tokens is None, "expect gpu tensor references to be None" - draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict(), return_python_output=False) - target_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict()) + draft_worker.execute_model.assert_called_once_with(execute_model_req) + target_worker.execute_model.assert_called_once_with(execute_model_req) @pytest.mark.parametrize('k', [0, 5]) @@ -462,6 +490,8 @@ def test_empty_input_batch(k: int, batch_size: int): rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) + target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -470,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - execute_model_data, prompts, prev_output_tokens = create_batch( - batch_size, k, prev_output_token_len=0) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + prev_output_token_len=0) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) - out = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + out = worker.execute_model(execute_model_req=execute_model_req) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" assert out[ 0].sampled_tokens is None, "expect gpu tensor references to be None" - draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict(), return_python_output=False) - target_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict()) + draft_worker.execute_model.assert_called_once_with(execute_model_req) + target_worker.execute_model.assert_called_once_with(execute_model_req) @pytest.mark.skip_global_cleanup @@ -492,8 +522,8 @@ def test_init_device(): """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as well as other GPU initialization. """ - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() + draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) + target_worker = mock_worker(use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 4637826f254d6..d52b22c30bd43 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, fields from itertools import count from typing import Dict, Iterable, List, Optional, Union from unittest.mock import MagicMock @@ -8,66 +7,29 @@ from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, SamplerOutput, SequenceData, - SequenceGroupMetadata, SequenceGroupOutput, +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine from vllm.worker.worker import Worker -@dataclass -class ExecuteModelData: - """Helper data structure which facilitates cleaner tests. - """ - seq_group_metadata_list: List[SequenceGroupMetadata] - blocks_to_swap_in: Dict[int, int] - blocks_to_swap_out: Dict[int, int] - blocks_to_copy: Dict[int, List[int]] - - def to_dict(self): - return dict( - (field.name, getattr(self, field.name)) for field in fields(self)) - - @classmethod - def from_dict(cls, d): - cleaned = dict((field.name, d[field.name]) for field in fields(cls)) - return cls(**cleaned) - - def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size -def create_execute_model_data( - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]] = None, - blocks_to_swap_out: Optional[Dict[int, int]] = None, - blocks_to_copy: Optional[Dict[int, int]] = None, -) -> ExecuteModelData: - if blocks_to_swap_in is None: - blocks_to_swap_in = {} - if blocks_to_swap_out is None: - blocks_to_swap_out = {} - if blocks_to_copy is None: - blocks_to_copy = {} - - return ExecuteModelData( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) - - def mock_worker(cls=None, vocab_size: int = 30_000, max_model_len: int = 2048, - rank: int = 0) -> MagicMock: + rank: int = 0, + use_spec: bool = True) -> MagicMock: if cls is None: cls = Worker - worker = MagicMock(spec=cls) + spec = cls if use_spec else None + + worker = MagicMock(spec=spec) worker.vocab_size = vocab_size worker.max_model_len = max_model_len worker.rank = rank @@ -118,6 +80,7 @@ def create_worker(cls: type, scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, cache_config=engine_config.cache_config, + load_config=engine_config.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -140,7 +103,7 @@ def create_seq_group_metadata_from_prompts( prompts: List[List[int]], num_gpu_blocks: int, block_size: int, - final_seq_lens: List[int], + final_prompt_lens: List[int], continuations: Optional[List[List[int]]] = None, seq_ids: Optional[List[int]] = None, ) -> List[SequenceGroupMetadata]: @@ -158,7 +121,7 @@ def create_seq_group_metadata_from_prompts( free_gpu_blocks.pop() for _ in range(round_up_to_next_block(final_len, block_size)) ] - for i, final_len in enumerate(final_seq_lens) + for i, final_len in enumerate(final_prompt_lens) } return [ @@ -197,6 +160,7 @@ def assert_logprobs_dict_allclose( def create_sampler_output_list( token_ids: torch.Tensor, probs: Iterable[Optional[torch.Tensor]], + logprobs: Iterable[Optional[torch.Tensor]], seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: num_steps, batch_size = token_ids.shape token_ids_by_step = token_ids.tolist() @@ -206,18 +170,19 @@ def create_sampler_output_list( return [ SamplerOutput(outputs=[ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( output_token=token_id, parent_seq_id=seq_ids[seq_index], - logprobs={token_id: 0}, + logprobs={token_id: Logprob(0)}, ) ], prompt_logprobs=None, ) for seq_index, token_id in enumerate(token_ids_by_step[step]) ], sampled_token_probs=probs[step], + logprobs=logprobs[step], sampled_token_ids=token_ids[step]) for step in range(num_steps) ] @@ -247,13 +212,12 @@ def create_batch(batch_size, prev_output_tokens = [[ next(iterator) for _ in range(prev_output_token_len) ] for _ in range(batch_size)] - final_seq_lens = [ + final_prompt_lens = [ len(prompt) + len(prev_output_token) + k + 1 for prompt, prev_output_token in zip(prompts, prev_output_tokens) ] - execute_model_data = create_execute_model_data( - create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks, - block_size, final_seq_lens, - prev_output_tokens, seq_ids), ) - return execute_model_data, prompts, prev_output_tokens + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, final_prompt_lens, + prev_output_tokens, seq_ids) + return seq_group_metadata_list, prompts, prev_output_tokens diff --git a/tests/tensorizer_loader/__init__.py b/tests/tensorizer_loader/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py new file mode 100644 index 0000000000000..1579d53a7fe29 --- /dev/null +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -0,0 +1,279 @@ +import gc +import json +import os +import subprocess +from unittest.mock import MagicMock, patch + +import openai +import pytest +import ray +import torch + +from vllm import SamplingParams +# yapf: disable +from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, + TensorSerializer, + is_vllm_tensorized, + load_with_tensorizer, + open_stream, + serialize_vllm_model) + +from ..utils import ServerRunner + +# yapf conflicts with isort for this docstring + + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) + +model_ref = "facebook/opt-125m" +tensorize_model_for_testing_script = os.path.join( + os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py") + + +def is_curl_installed(): + try: + subprocess.check_call(['curl', '--version']) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + +@pytest.fixture(autouse=True) +def tensorizer_config(): + config = TensorizerConfig(tensorizer_uri="vllm") + return config + + +@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent') +def test_load_with_tensorizer(mock_agent, tensorizer_config): + mock_linear_method = MagicMock() + mock_agent_instance = mock_agent.return_value + mock_agent_instance.deserialize.return_value = MagicMock() + + result = load_with_tensorizer(tensorizer_config, + quant_method=mock_linear_method) + + mock_agent.assert_called_once_with(tensorizer_config, + quant_method=mock_linear_method) + mock_agent_instance.deserialize.assert_called_once() + assert result == mock_agent_instance.deserialize.return_value + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_can_deserialize_s3(vllm_runner): + model_ref = "EleutherAI/pythia-1.4b" + tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" + + loaded_hf_model = vllm_runner(model_ref, + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=tensorized_path, + num_readers=1, + s3_endpoint="object.ord1.coreweave.com", + )) + + deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) + + assert deserialized_outputs + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_deserialized_encrypted_vllm_model_has_same_outputs( + vllm_runner, tmp_path): + vllm_model = vllm_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + key_path = tmp_path / (model_ref + ".key") + outputs = vllm_model.generate(prompts, sampling_params) + + config_for_serializing = TensorizerConfig(tensorizer_uri=model_path) + serialize_vllm_model(vllm_model.model.llm_engine, + config_for_serializing, + encryption_key_path=key_path) + + del vllm_model + gc.collect() + torch.cuda.empty_cache() + + config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, + encryption_keyfile=key_path) + + loaded_vllm_model = vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=config_for_deserializing) + + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + assert outputs == deserialized_outputs + + +def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, + tmp_path): + hf_model = hf_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + max_tokens = 50 + outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(hf_model.model) + del hf_model + gc.collect() + torch.cuda.empty_cache() + loaded_hf_model = vllm_runner(model_ref, + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=model_path, + num_readers=1, + )) + + deserialized_outputs = loaded_hf_model.generate_greedy( + prompts, max_tokens=max_tokens) + + assert outputs == deserialized_outputs + + +def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): + from huggingface_hub import snapshot_download + + from examples.multilora_inference import (create_test_prompts, + process_requests) + + model_ref = "meta-llama/Llama-2-7b-hf" + lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + test_prompts = create_test_prompts(lora_path) + + # Serialize model before deserializing and binding LoRA adapters + vllm_model = vllm_runner(model_ref, ) + model_path = tmp_path / (model_ref + ".tensors") + + serialize_vllm_model(vllm_model.model.llm_engine, + TensorizerConfig(tensorizer_uri=model_path)) + + del vllm_model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=model_path, + num_readers=1, + ), + enable_lora=True, + max_loras=1, + max_lora_rank=8, + max_cpu_loras=2, + max_num_seqs=50, + max_model_len=1000, + ) + process_requests(loaded_vllm_model.model.llm_engine, test_prompts) + + assert loaded_vllm_model + + +def test_load_without_tensorizer_load_format(vllm_runner): + with pytest.raises(ValueError): + vllm_runner( + model_ref, + model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): + ## Serialize model + vllm_model = vllm_runner(model_ref, ) + model_path = tmp_path / (model_ref + ".tensors") + + serialize_vllm_model(vllm_model.model.llm_engine, + TensorizerConfig(tensorizer_uri=model_path)) + + model_loader_extra_config = { + "tensorizer_uri": str(model_path), + } + + del vllm_model + gc.collect() + torch.cuda.empty_cache() + + ## Start OpenAI API server + openai_args = [ + "--model", model_ref, "--dtype", "float16", "--load-format", + "tensorizer", "--model-loader-extra-config", + json.dumps(model_loader_extra_config), "--port", "8000" + ] + + server = ServerRunner.remote(openai_args) + + assert ray.get(server.ready.remote()) + print("Server ready.") + + client = openai.OpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + completion = client.completions.create(model=model_ref, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + +def test_raise_value_error_on_invalid_load_format(vllm_runner): + with pytest.raises(ValueError): + vllm_runner( + model_ref, + load_format="safetensors", + model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) + + +def test_tensorizer_with_tp(vllm_runner): + with pytest.raises(ValueError): + model_ref = "EleutherAI/pythia-1.4b" + tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" + + vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=tensorized_path, + num_readers=1, + s3_endpoint="object.ord1.coreweave.com", + ), + tensor_parallel_size=2, + ) + + +def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path)) + + vllm_model = vllm_runner(model_ref) + outputs = vllm_model.generate(prompts, sampling_params) + serialize_vllm_model(vllm_model.model.llm_engine, config) + + assert is_vllm_tensorized(config) + del vllm_model + gc.collect() + torch.cuda.empty_cache() + + loaded_vllm_model = vllm_runner(model_ref, + load_format="tensorizer", + model_loader_extra_config=config) + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + assert outputs == deserialized_outputs diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 3b257ac062f56..0fbe3dae1ff08 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -70,8 +70,14 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, for prompt in prompts: hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - tokenizer.tokenizer.eos_token_id, lora_request) + seq = Sequence(seq_id, + inputs={ + "prompt": prompt, + "prompt_token_ids": prompt_token_ids, + }, + block_size=block_size, + eos_token_id=tokenizer.tokenizer.eos_token_id, + lora_request=lora_request) num_blocks = len(prompt_token_ids) // block_size for idx in range(num_blocks): diff --git a/tests/test_config.py b/tests/test_config.py index 13a9f76212679..7cbdaeca9c4d4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,29 @@ +import pytest + from vllm.config import ModelConfig +MODEL_IDS_EXPECTED = [ + ("Qwen/Qwen1.5-7B", 32768), + ("mistralai/Mistral-7B-v0.1", 4096), + ("mistralai/Mistral-7B-Instruct-v0.2", 32768), +] + + +@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED) +def test_disable_sliding_window(model_id_expected): + model_id, expected = model_id_expected + model_config = ModelConfig( + model_id, + model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None, + disable_sliding_window=True, + ) + assert model_config.max_model_len == expected + def test_get_sliding_window(): TEST_SLIDING_WINDOW = 4096 @@ -11,8 +35,6 @@ def test_get_sliding_window(): "Qwen/Qwen1.5-7B", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, @@ -30,8 +52,6 @@ def test_get_sliding_window(): "mistralai/Mistral-7B-v0.1", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, @@ -40,4 +60,58 @@ def test_get_sliding_window(): assert mistral_model_config.get_sliding_window() is None mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW - assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW \ No newline at end of file + assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW + + +def test_rope_scaling(): + TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0} + LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0} + + llama_model_config = ModelConfig( + "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None + assert llama_model_config.max_model_len == 8192 + + llama_model_config = ModelConfig( + "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + rope_scaling=TEST_ROPE_SCALING, + ) + assert getattr(llama_model_config.hf_config, "rope_scaling", + None) == TEST_ROPE_SCALING + assert llama_model_config.max_model_len == 16384 + + longchat_model_config = ModelConfig( + "lmsys/longchat-13b-16k", + "lmsys/longchat-13b-16k", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + assert getattr(longchat_model_config.hf_config, "rope_scaling", + None) == LONGCHAT_ROPE_SCALING + assert longchat_model_config.max_model_len == 16384 + + longchat_model_config = ModelConfig( + "lmsys/longchat-13b-16k", + "lmsys/longchat-13b-16k", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + rope_scaling=TEST_ROPE_SCALING, + ) + assert getattr(longchat_model_config.hf_config, "rope_scaling", + None) == TEST_ROPE_SCALING + assert longchat_model_config.max_model_len == 4096 diff --git a/tests/test_inputs.py b/tests/test_inputs.py new file mode 100644 index 0000000000000..887c7101decda --- /dev/null +++ b/tests/test_inputs.py @@ -0,0 +1,53 @@ +from typing import List + +import pytest + +from vllm.inputs import parse_and_batch_prompt + +STRING_INPUTS = [ + '', + 'foo', + 'foo bar', + 'foo baz bar', + 'foo bar qux baz', +] + +TOKEN_INPUTS = [ + [-1], + [1], + [1, 2], + [1, 3, 4], + [1, 2, 4, 3], +] + +INPUTS_SLICES = [ + slice(None, None, -1), + slice(None, None, 2), + slice(None, None, -2), +] + + +def test_parse_single_batch_empty(): + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([]) + + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([[]]) + + +@pytest.mark.parametrize('string_input', STRING_INPUTS) +def test_parse_single_batch_string_consistent(string_input: str): + assert parse_and_batch_prompt(string_input) \ + == parse_and_batch_prompt([string_input]) + + +@pytest.mark.parametrize('token_input', TOKEN_INPUTS) +def test_parse_single_batch_token_consistent(token_input: List[int]): + assert parse_and_batch_prompt(token_input) \ + == parse_and_batch_prompt([token_input]) + + +@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES) +def test_parse_single_batch_string_slice(inputs_slice: slice): + assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \ + == parse_and_batch_prompt(STRING_INPUTS[inputs_slice]) diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 0000000000000..74f1125fb37c9 --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,214 @@ +import json +import logging +import os +import sys +import tempfile +from json.decoder import JSONDecodeError +from tempfile import NamedTemporaryFile +from typing import Any +from unittest.mock import patch +from uuid import uuid4 + +import pytest + +from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger, + enable_trace_function_call, init_logger) +from vllm.logging import NewLineFormatter + + +def f1(x): + return f2(x) + + +def f2(x): + return x + + +def test_trace_function_call(): + fd, path = tempfile.mkstemp() + cur_dir = os.path.dirname(__file__) + enable_trace_function_call(path, cur_dir) + f1(1) + with open(path, 'r') as f: + content = f.read() + + assert "f1" in content + assert "f2" in content + sys.settrace(None) + os.remove(path) + + +def test_default_vllm_root_logger_configuration(): + """This test presumes that VLLM_CONFIGURE_LOGGING (default: True) and + VLLM_LOGGING_CONFIG_PATH (default: None) are not configured and default + behavior is activated.""" + logger = logging.getLogger("vllm") + assert logger.level == logging.DEBUG + assert not logger.propagate + + handler = logger.handlers[0] + assert handler.stream == sys.stdout + assert handler.level == logging.INFO + + formatter = handler.formatter + assert formatter is not None + assert isinstance(formatter, NewLineFormatter) + assert formatter._fmt == _FORMAT + assert formatter.datefmt == _DATE_FORMAT + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +@patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", None) +def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger(): + """This test presumes that VLLM_CONFIGURE_LOGGING (default: True) and + VLLM_LOGGING_CONFIG_PATH (default: None) are not configured and default + behavior is activated.""" + root_logger = logging.getLogger("vllm") + root_handler = root_logger.handlers[0] + + unique_name = f"vllm.{uuid4()}" + logger = init_logger(unique_name) + assert logger.name == unique_name + assert logger.level == logging.NOTSET + assert not logger.handlers + assert logger.propagate + + message = "Hello, world!" + with patch.object(root_handler, "emit") as root_handle_mock: + logger.info(message) + + root_handle_mock.assert_called_once() + _, call_args, _ = root_handle_mock.mock_calls[0] + log_record = call_args[0] + assert unique_name == log_record.name + assert message == log_record.msg + assert message == log_record.msg + assert log_record.levelno == logging.INFO + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0) +@patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", None) +def test_logger_configuring_can_be_disabled(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however mocks are used to ensure no changes in behavior or + configuration occur.""" + + with patch("logging.config.dictConfig") as dict_config_mock: + _configure_vllm_root_logger() + dict_config_mock.assert_not_called() + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +@patch( + "vllm.logger.VLLM_LOGGING_CONFIG_PATH", + "/if/there/is/a/file/here/then/you/did/this/to/yourself.json", +) +def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however it fails before any change in behavior or + configuration occurs.""" + with pytest.raises(RuntimeError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type == RuntimeError + assert "File does not exist" in str(ex_info) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +def test_an_error_is_raised_when_custom_logging_config_is_invalid_json(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however it fails before any change in behavior or + configuration occurs.""" + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write("---\nloggers: []\nversion: 1") + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name): + with pytest.raises(JSONDecodeError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type == JSONDecodeError + assert "Expecting value" in str(ex_info) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +@pytest.mark.parametrize("unexpected_config", ( + "Invalid string", + [{ + "version": 1, + "loggers": [] + }], + 0, +)) +def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( + unexpected_config: Any): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however it fails before any change in behavior or + configuration occurs.""" + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write(json.dumps(unexpected_config)) + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name): + with pytest.raises(ValueError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type == ValueError + assert "Invalid logging config. Expected Dict, got" in str(ex_info) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +def test_custom_logging_config_is_parsed_and_used_when_provided(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however mocks are used to ensure no changes in behavior or + configuration occur.""" + valid_logging_config = { + "loggers": { + "vllm.test_logger.logger": { + "handlers": [], + "propagate": False, + } + }, + "version": 1 + } + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write(json.dumps(valid_logging_config)) + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name), patch( + "logging.config.dictConfig") as dict_config_mock: + _configure_vllm_root_logger() + assert dict_config_mock.called_with(valid_logging_config) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0) +def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however mocks are used to ensure no changes in behavior or + configuration occur.""" + valid_logging_config = { + "loggers": { + "vllm.test_logger.logger": { + "handlers": [], + } + }, + "version": 1 + } + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write(json.dumps(valid_logging_config)) + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name): + with pytest.raises(RuntimeError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type is RuntimeError + expected_message_snippet = ( + "VLLM_CONFIGURE_LOGGING evaluated to false, but " + "VLLM_LOGGING_CONFIG_PATH was given.") + assert expected_message_snippet in str(ex_info) + + # Remember! The root logger is assumed to have been configured as + # though VLLM_CONFIGURE_LOGGING=1 and VLLM_LOGGING_CONFIG_PATH=None. + root_logger = logging.getLogger("vllm") + other_logger_name = f"vllm.test_logger.{uuid4()}" + other_logger = init_logger(other_logger_name) + assert other_logger.handlers != root_logger.handlers + assert other_logger.level != root_logger.level + assert other_logger.propagate diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index fe321520114f7..4ee980505a3ab 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -6,9 +6,10 @@ import torch from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.utils import is_pin_memory_available class MockLogitsProcessor(LogitsProcessor): @@ -29,16 +30,15 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=input_tensor.dtype) logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - model_runner = ModelRunner(None, None, None, None, None) - return input_tensor, fake_logits, logits_processor, model_runner + return input_tensor, fake_logits, logits_processor RANDOM_SEEDS = list(range(128)) @@ -53,8 +53,7 @@ def test_logits_processors(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, logits_processor, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) # This sample logits processor gives infinite score to the i-th token, # where i is the length of the input sequence. @@ -64,7 +63,7 @@ def pick_ith(token_ids, logits): return logits seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -75,11 +74,14 @@ def pick_ith(token_ids, logits): logits_processors=[pick_ith]), block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=device, + pin_memory=is_pin_memory_available()) logits_processor_output = logits_processor( embedding=None, hidden_states=input_tensor, @@ -90,5 +92,3 @@ def pick_ith(token_ids, logits): fake_logits *= logits_processor.scale assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1], 1e-4) - - del model_runner diff --git a/tests/test_sequence.py b/tests/test_sequence.py index b16bdc141e57c..3136402518b9f 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,45 +1,18 @@ -import time -from typing import Optional - import pytest -from vllm import SamplingParams -from vllm.lora.request import LoRARequest -from vllm.sequence import (SamplerOutput, Sequence, SequenceData, - SequenceGroup, SequenceGroupOutput, SequenceOutput) - - -def create_dummy_prompt( - request_id: str, - prompt_length: int, - block_size: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - use_beam_search: bool = False, - best_of: int = 1, -) -> SequenceGroup: - if not block_size: - block_size = prompt_length - - # Create dummy prompt sequence with tokens 0...block_size-1 - # and prompt "0 ... block_size". - prompt_tokens = list(range(prompt_length)) - prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup( - request_id, [prompt], - SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - time.time(), lora_request) +from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput, + SequenceData, SequenceOutput) - return seq_group +from .core.utils import create_dummy_prompt @pytest.fixture def sample_outputs(): return [ - SequenceGroupOutput(samples=[ + CompletionSequenceGroupOutput(samples=[ SequenceOutput(parent_seq_id=0, output_token=i, logprobs={}) ], - prompt_logprobs=None) for i in range(5) + prompt_logprobs=None) for i in range(5) ] @@ -60,10 +33,10 @@ def test_sampler_output_getitem(sampler_output, sample_outputs): def test_sampler_output_setitem(sampler_output): - new_output = SequenceGroupOutput(samples=[ + new_output = CompletionSequenceGroupOutput(samples=[ SequenceOutput(parent_seq_id=0, output_token=99, logprobs={}) ], - prompt_logprobs=None) + prompt_logprobs=None) sampler_output[2] = new_output assert sampler_output[2] == new_output @@ -102,7 +75,7 @@ def test_sequence_data_prefill(): def test_sequence_group_stage(): - seq_group = create_dummy_prompt("1", 12) + _, seq_group = create_dummy_prompt("1", 12) assert seq_group.is_prefill() is True seq_group.update_num_computed_tokens(6) assert seq_group.is_prefill() is True diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py new file mode 100644 index 0000000000000..8540e98da366a --- /dev/null +++ b/tests/test_sharded_state_loader.py @@ -0,0 +1,90 @@ +import os +import shutil +from tempfile import TemporaryDirectory + +import pytest +import torch +from huggingface_hub import snapshot_download + +from vllm import LLM, SamplingParams +from vllm.model_executor.model_loader.loader import ShardedStateLoader + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create a sampling params object. +sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + seed=0, + max_tokens=256, + ignore_eos=True, +) + + +def test_filter_subtensors(): + state_dict = { + "a": torch.empty(2), + "b": torch.empty((2, 4)), + "c": torch.empty((2, 4, 8)), + } + state_dict.update({ + "x": state_dict["b"], + "y": state_dict["c"][1, 2, :], + "z": state_dict["c"][1, :, 4], + }) + filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) + assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") + for key, tensor in filtered_state_dict.items(): + assert tensor.equal(state_dict[key]) + + +@pytest.mark.parametrize("enable_lora", [False, True]) +def test_sharded_state_loader(enable_lora): + weights_patterns = ("*.bin", "*.pt", "*.safetensors") + + with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir: + input_dir = snapshot_download("meta-llama/Llama-2-7b-hf", + cache_dir=cache_dir) + + llm = LLM( + model=input_dir, + worker_use_ray=True, + gpu_memory_utilization=0.3, + ) + + # Dump worker states to output directory + model_executor = llm.llm_engine.model_executor + model_executor.save_sharded_state(path=output_dir) + # Copy metadata files to output directory + for file in os.listdir(input_dir): + if not any(file.endswith(ext) for ext in weights_patterns): + shutil.copy(f"{input_dir}/{file}", output_dir) + del llm.llm_engine.model_executor + + llm_before = LLM( + model=input_dir, + worker_use_ray=True, + enable_lora=enable_lora, + gpu_memory_utilization=0.3, + ) + gen_before = llm_before.generate(prompts, sampling_params) + out_before = [gen.outputs[0].__dict__ for gen in gen_before] + del llm_before.llm_engine.model_executor + + llm_after = LLM( + model=output_dir, + worker_use_ray=True, + enable_lora=enable_lora, + gpu_memory_utilization=0.3, + load_format="sharded_state", + ) + gen_after = llm_after.generate(prompts, sampling_params) + out_after = [gen.outputs[0].__dict__ for gen in gen_after] + del llm_after.llm_engine.model_executor + + assert out_before == out_after diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000000..a6c3896fa43bf --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,118 @@ +import asyncio +import sys +from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol, + Tuple, TypeVar) + +import pytest + +from vllm.utils import deprecate_kwargs, merge_async_iterators + +from .utils import error_on_warning + +if sys.version_info < (3, 10): + if TYPE_CHECKING: + _AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any]) + _AwaitableT_co = TypeVar("_AwaitableT_co", + bound=Awaitable[Any], + covariant=True) + + class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]): + + def __anext__(self) -> _AwaitableT_co: + ... + + def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT": + return i.__anext__() + + +@pytest.mark.asyncio +async def test_merge_async_iterators(): + + async def mock_async_iterator(idx: int) -> AsyncIterator[str]: + try: + while True: + yield f"item from iterator {idx}" + await asyncio.sleep(0.1) + except asyncio.CancelledError: + pass + + iterators = [mock_async_iterator(i) for i in range(3)] + merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators( + *iterators) + + async def stream_output(generator: AsyncIterator[Tuple[int, str]]): + async for idx, output in generator: + print(f"idx: {idx}, output: {output}") + + task = asyncio.create_task(stream_output(merged_iterator)) + await asyncio.sleep(0.5) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + for iterator in iterators: + try: + await asyncio.wait_for(anext(iterator), 1) + except StopAsyncIteration: + # All iterators should be cancelled and print this message. + print("Iterator was cancelled normally") + except (Exception, asyncio.CancelledError) as e: + raise AssertionError() from e + + +def test_deprecate_kwargs_always(): + + @deprecate_kwargs("old_arg", is_deprecated=True) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_never(): + + @deprecate_kwargs("old_arg", is_deprecated=False) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_dynamic(): + is_deprecated = True + + @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + is_deprecated = False + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_additional_message(): + + @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="abcd"): + dummy(old_arg=1) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 9bc9becb2a6f1..8d019fe5f38ca 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -123,8 +123,10 @@ def create_sequence(prompt_token_ids=None): prompt_token_ids = prompt_token_ids or [1] return Sequence( seq_id=0, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + }, block_size=16, ) diff --git a/tests/tokenization/test_tokenizer.py b/tests/tokenization/test_tokenizer.py new file mode 100644 index 0000000000000..8db7204f15d4e --- /dev/null +++ b/tests/tokenization/test_tokenizer.py @@ -0,0 +1,20 @@ +import pytest +from transformers import PreTrainedTokenizerBase + +from vllm.transformers_utils.tokenizer import get_tokenizer + +TOKENIZER_NAMES = [ + "facebook/opt-125m", + "gpt2", +] + + +@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES) +def test_tokenizer_revision(tokenizer_name: str): + # Assume that "main" branch always exists + tokenizer = get_tokenizer(tokenizer_name, revision="main") + assert isinstance(tokenizer, PreTrainedTokenizerBase) + + # Assume that "never" branch always does not exist + with pytest.raises(OSError, match='not a valid git identifier'): + get_tokenizer(tokenizer_name, revision="never") diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000000..329842911e159 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,103 @@ +import os +import subprocess +import sys +import time +import warnings +from contextlib import contextmanager + +import ray +import requests + +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.utils import get_open_port + +# Path to root of repository so that utilities can be imported by ray workers +VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)) + + +@ray.remote(num_gpus=1) +class ServerRunner: + MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds + + def __init__(self, args): + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + self.proc = subprocess.Popen( + ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_server() + + def ready(self): + return True + + def _wait_for_server(self): + # run health check + start = time.time() + while True: + try: + if requests.get( + "http://localhost:8000/health").status_code == 200: + break + except Exception as err: + if self.proc.poll() is not None: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > self.MAX_SERVER_START_WAIT_S: + raise RuntimeError( + "Server failed to start in time.") from err + + def __del__(self): + if hasattr(self, "proc"): + self.proc.terminate() + + +def init_test_distributed_environment( + tp_size: int, + pp_size: int, + rank: int, + distributed_init_port: str, + local_rank: int = -1, +) -> None: + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=pp_size * tp_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=local_rank) + ensure_model_parallel_initialized(tp_size, pp_size) + + +def multi_process_tensor_parallel( + tp_size: int, + pp_size: int, + test_target, +) -> None: + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. + ray.init(runtime_env={"working_dir": VLLM_PATH}) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(tp_size * pp_size): + refs.append( + test_target.remote(tp_size, pp_size, rank, distributed_init_port)) + ray.get(refs) + + ray.shutdown() + + +@contextmanager +def error_on_warning(): + """ + Within the scope of this context manager, tests will fail if any warning + is emitted. + """ + with warnings.catch_warnings(): + warnings.simplefilter("error") + + yield diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 5b6f001f62fa7..92de545acd53d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,24 +1,47 @@ import pytest import torch -from vllm.config import ModelConfig +from vllm.distributed.parallel_state import init_distributed_environment +from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size +def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + model_runner = ModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + return model_runner + + @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): - model_runner = ModelRunner(None, None, None, None, None) - model_runner.set_block_size(16) + model_runner = _create_model_runner( + "facebook/opt-125m", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + ) - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = SequenceData(list(range(prompt_len))) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -31,49 +54,54 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices = [] selected_token_start_idx = 0 - for prompt_len in prompt_lens: + for seq_len in seq_lens: expected_selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += prompt_len - (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) - assert return_prompt_lens == prompt_lens + seq_len - 1) + selected_token_start_idx += seq_len + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = model_input.slot_mapping + assert return_seq_lens == seq_lens + assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device - assert attn_metadata.is_prompt is True - assert torch.allclose(attn_metadata.prompt_lens_tensor, - torch.tensor(prompt_lens, device=device)) - assert attn_metadata.prompt_lens == prompt_lens - assert attn_metadata.num_prompt_tokens == sum(prompt_lens) - assert attn_metadata.num_generation_tokens == 0 - assert attn_metadata.max_prompt_len == max(prompt_lens) + assert attn_metadata.num_prefills > 0 + assert attn_metadata.num_decode_tokens == 0 + assert torch.allclose( + attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_prefill_seq_len == max(seq_lens) + assert attn_metadata.max_decode_seq_len == 0 # Test subquery start locs. start_idx = 0 start_loc = [start_idx] - for prompt_len in prompt_lens: - start_idx += prompt_len + for seq_len in seq_lens: + start_idx += seq_len start_loc.append(start_idx) assert torch.allclose( - attn_metadata.subquery_start_loc, + attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) # Test seq start locs. Note that for normal prefill it is - # equivalent to subquery_start_loc. + # equivalent to query_start_loc. start_idx = 0 seq_start_loc = [start_idx] - for prompt_len in prompt_lens: - start_idx += prompt_len + for seq_len in seq_lens: + start_idx += seq_len seq_start_loc.append(start_idx) assert torch.allclose( attn_metadata.seq_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) - assert attn_metadata.max_context_len is None assert torch.allclose( - attn_metadata.context_lens, - torch.zeros(attn_metadata.context_lens.shape[0], + attn_metadata.context_lens_tensor, + torch.zeros(attn_metadata.context_lens_tensor.shape[0], dtype=torch.int, device=device)) @@ -83,23 +111,25 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.block_tables, expected) # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False - assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) torch.testing.assert_close(input_tokens, input_positions) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=model_runner.device, + pin_memory=model_runner.pin_memory) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) - torch.testing.assert_close(input_tokens, input_positions) + torch.allclose(input_tokens, input_positions) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, @@ -110,29 +140,28 @@ def test_prepare_prompt(batch_size): @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_decode_cuda_graph(batch_size): - model_config = ModelConfig( + model_runner = _create_model_runner( "facebook/opt-125m", - "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", - revision=None, enforce_eager=False, + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, ) - model_runner = ModelRunner(model_config, None, None, None, None) - model_runner.set_block_size(16) - prompt_lens = [] + context_lens = [] seq_group_metadata_list = [] + # Assume each seq group finishes prefill. for i in range(batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = list(range(prompt_len)) + context_len = i % (model_runner.block_size - 1) + 1 + context_lens.append(context_len) + seq_data = list(range(context_len)) seq_data = SequenceData(seq_data) + seq_data.update_num_computed_tokens(context_len) + # Append one token ID since prefill is finished. + seq_data.append_token_id(1, 0) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -143,23 +172,48 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - input_tokens, input_positions, attn_metadata, _, _, _ = ( - model_runner._prepare_decode(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, model_input.input_positions, + model_input.attn_metadata, model_input.slot_mapping) + assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device - assert attn_metadata.is_prompt is False - assert attn_metadata.prompt_lens is None - assert attn_metadata.num_prompt_tokens == 0 - assert attn_metadata.num_generation_tokens == expected_bs - assert attn_metadata.max_prompt_len is None - assert attn_metadata.subquery_start_loc is None - assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_context_len == max(prompt_lens) + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + seq_lens = [context_len + 1 for context_len in context_lens] + # seq_lens are padded to expected_bs + for _ in range(expected_bs - len(seq_lens)): + seq_lens.append(1) + assert attn_metadata.seq_lens == seq_lens + start_idx = 0 + start_loc = [start_idx] + for _ in context_lens: + # decode has only 1 token for query. + start_idx += 1 + start_loc.append(start_idx) assert torch.allclose( - attn_metadata.context_lens[:len(prompt_lens)], - torch.tensor(prompt_lens, dtype=torch.int, device=device)) + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.seq_start_loc, + torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) + + assert torch.allclose( + attn_metadata.context_lens_tensor, + torch.tensor(context_lens, dtype=torch.int, device=device)) + assert attn_metadata.max_decode_seq_len == max(seq_lens) + assert torch.allclose( + attn_metadata.seq_lens_tensor[:len(seq_lens)], + torch.tensor(seq_lens, dtype=torch.int, device=device)) # block table's first index corresponds to each batch, meaning in # decoding it is each token. @@ -168,25 +222,154 @@ def test_prepare_decode_cuda_graph(batch_size): # It is padded up to assert attn_metadata.block_tables.shape[1] == ( model_runner.get_max_block_per_batch()) - # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is True - assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (expected_bs, ) - assert input_positions.shape == (expected_bs, ) - torch.testing.assert_close(input_tokens, input_positions) + assert len(input_tokens) == expected_bs + assert len(input_positions) == expected_bs + torch.allclose(input_tokens, input_positions) # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for prompt_len in prompt_lens: + for _ in context_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # query lens is all 1 for decode. + query_lens=[1 for _ in range(len(context_lens))], + device=model_runner.device, + pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) + + +def test_empty_seq_group(): + """Verify prepare prompt and decode returns empty output.""" + model_runner = _create_model_runner( + "facebook/opt-125m", + seed=0, + dtype="float16", + enforce_eager=False, + ) + seq_group_metadata_list = [] + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + ) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, slot_mapping, + return_seq_lens) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + model_input.seq_lens, + ) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + assert len(return_seq_lens) == 0 + + +@pytest.fixture +def distributed_init(): + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}", + local_rank=0) + + +@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, distributed_init): + model_runner = _create_model_runner( + "facebook/opt-125m", + seed=0, + dtype="float16", + enforce_eager=enforce_eager, + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=True, + ) + + # Add prefill requests. + seq_lens = [] + seq_group_metadata_list = [] + prefill_metadata_list = [] + decode_metadata_list = [] + block_tables = {0: [1]} + prefill_batch_size = batch_size // 2 + decode_batch_size = batch_size - prefill_batch_size + for i in range(prefill_batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) + prefill_metadata_list.append(seq_group_metadata) + + # Add decode requests + for i in range(prefill_batch_size, batch_size): + # make sure all tokens fit into one block + context_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(context_len)) + seq_data = SequenceData(prompt_toks) + seq_data.append_token_id(1, 0) + seq_data.update_num_computed_tokens(context_len) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + decode_metadata_list.append(seq_group_metadata) + + (input_tokens, input_positions, attn_metadata, _, _, _, + _) = model_runner.prepare_input_tensors(seq_group_metadata_list) + + prefill_meta_actual = attn_metadata.prefill_metadata + decode_meta_actual = attn_metadata.decode_metadata + + assert len(attn_metadata.slot_mapping) == len(input_tokens) + assert len(input_positions) == len(input_tokens) + assert attn_metadata.num_prefills == prefill_batch_size + assert attn_metadata.num_decode_tokens == decode_batch_size + assert attn_metadata.num_prefill_tokens == sum(seq_lens) + + # Verify attn metadata is consistent. We don't need to test individual + # values here because they are tested above. + attn_metadata = model_runner._prepare_model_input( + seq_group_metadata_list).attn_metadata + + for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), + vars(prefill_meta_actual)): + assert attr_expected[1] == attr_actual[1] + for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata), + vars(decode_meta_actual)): + assert attr_expected[1] == attr_actual[1] diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 8edb1cf05c08e..d941ffdb5588a 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -1,6 +1,7 @@ import torch from vllm.engine.arg_utils import EngineArgs +from vllm.sequence import ExecuteModelRequest from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.worker import Worker @@ -23,6 +24,7 @@ def test_swap() -> None: scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, cache_config=engine_config.cache_config, + load_config=engine_config.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -52,27 +54,36 @@ def test_swap() -> None: a.cuda(), b.cuda(), rtol=0.0, atol=0.0) # Test swap out. - blocks_to_swap_out = {3: 72, 56: 35, 84: 34} - worker.execute_model(seq_group_metadata_list=[], - blocks_to_swap_in={}, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy={}) + blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)] + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=[], + blocks_to_swap_in=[], + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=[], + ) + worker.execute_model(execute_model_req=execute_model_req) + for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in blocks_to_swap_out.items(): + for src, dst in blocks_to_swap_out: assert allclose(gpu_key_cache[src], cpu_key_cache[dst]) assert allclose(gpu_value_cache[src], cpu_value_cache[dst]) # Test swap in. - blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71} - worker.execute_model(seq_group_metadata_list=[], - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out={}, - blocks_to_copy={}) + execute_model_req.blocks_to_swap_out = [] + execute_model_req.blocks_to_swap_in = [ + (19, 45), + (67, 23), + (12, 78), + (40, 99), + (1, 71), + ] + worker.execute_model(execute_model_req=execute_model_req) + for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in blocks_to_swap_in.items(): + for src, dst in execute_model_req.blocks_to_swap_in: assert allclose(gpu_key_cache[dst], cpu_key_cache[src]) assert allclose(gpu_value_cache[dst], cpu_value_cache[src]) diff --git a/vllm/__init__.py b/vllm/__init__.py index 2c1fd40573240..dc59bf4a81931 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -3,23 +3,32 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_ray_cluster from vllm.entrypoints.llm import LLM +from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry -from vllm.outputs import CompletionOutput, RequestOutput +from vllm.outputs import (CompletionOutput, EmbeddingOutput, + EmbeddingRequestOutput, RequestOutput) +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -__version__ = "0.4.0.post1" +__version__ = "0.4.3" __all__ = [ "LLM", "ModelRegistry", + "PromptStrictInputs", + "TextPrompt", + "TokensPrompt", "SamplingParams", "RequestOutput", "CompletionOutput", + "EmbeddingOutput", + "EmbeddingRequestOutput", "LLMEngine", "EngineArgs", "AsyncLLMEngine", "AsyncEngineArgs", "initialize_ray_cluster", + "PoolingParams", ] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py new file mode 100644 index 0000000000000..fe1c4eff698d3 --- /dev/null +++ b/vllm/_custom_ops.py @@ -0,0 +1,335 @@ +from typing import Optional, Tuple, Type + +import torch + +try: + from vllm._C import cache_ops as vllm_cache_ops + from vllm._C import ops as vllm_ops +except ImportError: + pass + + +# activation ops +def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.silu_and_mul(out, x) + + +def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_and_mul(out, x) + + +def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_tanh_and_mul(out, x) + + +def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_fast(out, x) + + +def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_new(out, x) + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + vllm_ops.paged_attention_v1( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + vllm_ops.paged_attention_v2( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) + + +# pos encoding ops +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, + is_neox) + + +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + vllm_ops.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) + + +# layer norm ops +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + vllm_ops.rms_norm(out, input, weight, epsilon) + + +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) + + +# quantization ops +# awq +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: + return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, + thy) + + +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, + scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: + return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + + +# gptq +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, + bit: int) -> torch.Tensor: + return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, use_exllama, bit) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + vllm_ops.gptq_shuffle(q_weight, q_perm, bit) + + +# squeezellm +def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, + lookup_table: torch.Tensor) -> None: + vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) + + +# marlin +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + size_n, size_k) + + +# marlin_24 +def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, num_bits: int, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return vllm_ops.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, + workspace, num_bits, size_m, size_n, + size_k) + + +# cutlass +def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, + a_scales: torch.Tensor, b_scales: torch.Tensor, + out_dtype: Type[torch.dtype]) -> torch.Tensor: + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) + + return out + + +# aqlm +def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, + codebook_partition_sizes, bias) + + +def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: torch.Tensor) -> torch.Tensor: + return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) + + +# gptq_marlin +def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + num_bits) + + +def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: int, size_n: int, size_k: int, + is_k_full: bool) -> torch.Tensor: + return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, + workspace, num_bits, size_m, size_n, + size_k, is_k_full) + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + batch_dim_padding: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensor for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + batch_dim_padding: If specified, pad the first dimension + of the output to at least this value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + if batch_dim_padding: + shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) + output = torch.empty(shape, + device=input.device, + dtype=torch.float8_e4m3fn) + else: + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + if scale is None: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + vllm_ops.static_scaled_fp8_quant(output, input, scale) + return output, scale + + +# int8 +def static_scaled_int8_quant(input: torch.Tensor, + scale: float) -> torch.Tensor: + """ + Quantize the input tensor to int8 and return the quantized tensor. + + Args: + input: The input tensor to be quantized to int8. + scale: Scaling factor for the int8 quantization. + + Returns: + torch.Tensor: Output tensor in int8. + """ + q = torch.empty_like(input, dtype=torch.int8) + vllm_ops.static_scaled_int8_quant(q, input, scale) + return q + + +# moe +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, kv_scale) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, +) -> None: + vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) + + +def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, + block_mapping: torch.Tensor) -> None: + vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + vllm_cache_ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float = 1.0) -> None: + vllm_ops.convert_fp8(output, input, torch.Tensor([scale])) + + +#TODO: cuda_utils, custom_ar diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 9acb82c0df2c2..f6bce9a187c64 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -4,6 +4,7 @@ from vllm.attention.selector import get_attn_backend __all__ = [ + "Attention", "AttentionBackend", "AttentionMetadata", "Attention", diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a03cf2dd7a6fa..6396103bf5efa 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, + TypeVar) import torch @@ -8,6 +9,11 @@ class AttentionBackend(ABC): """Abstract class for attention backends.""" + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + @staticmethod @abstractmethod def get_impl_cls() -> Type["AttentionImpl"]: @@ -33,7 +39,7 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: raise NotImplementedError @@ -41,25 +47,59 @@ def swap_blocks( @abstractmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: raise NotImplementedError @dataclass class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + + @property + @abstractmethod + def prefill_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run prefill + attention.""" + pass + + @property + @abstractmethod + def decode_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run decode + attention.""" + pass - def asdict_zerocopy(self) -> Dict[str, Any]: + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fields is None: + skip_fields = set() # Note that if we add dataclasses as fields, they will need # similar handling. return { field.name: getattr(self, field.name) - for field in fields(self) + for field in fields(self) if field.name not in skip_fields } -class AttentionImpl(ABC): +T = TypeVar("T", bound=AttentionMetadata) + + +class AttentionImpl(ABC, Generic[T]): @abstractmethod def __init__( @@ -70,6 +110,8 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: raise NotImplementedError @@ -80,7 +122,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - kv_scale: float, + attn_metadata: T, + kv_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py new file mode 100644 index 0000000000000..dce2b83615b7a --- /dev/null +++ b/vllm/attention/backends/blocksparse_attn.py @@ -0,0 +1,410 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn, get_head_sliding_step) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + + +@dataclass +class BlocksparseParams: + max_seqlen: int + + # Num q heads per tensor-parallel rank/partition + num_heads: int # per TP partition + # Num kv heads per tensor-parallel rank/partition + num_kv_heads: int + + # block size used for blocksparse attention. + # This is the block_size used in `local_blocks`, `vert_stride`. + block_size: int + + # Number of blocks for local attention, i.e., number of + # local attended tokens / `sparse_block_size` + local_blocks: int + + # Attend to one block per every `vert_stride` blocks. + # Controlling the sparsity + vert_stride: int + """ + If to use the same vertical stride offset for all heads, + i.e., attend to the same block of tokens on all heads. + By default, it is False, i.e., attention on the non-local + blocks depends on the `head_idx`, that is on + blocks satisfying + `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0` + where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`, + `block_idx = position_id // sparse_block_size`. + See `..ops.blocksparse_attention.utils:get_sparse_attn_mask` + for more detail. + """ + homo_head: bool = False + + # If within a group, the kv offsets that each q attends is the same or no. + homo_head_group: bool = False + + # Decided by homo_head and homo_head group + head_sliding_step: int = field(init=False) + + # range of q heads to for a TP rank + active_head_range: Tuple = field(init=False) + + def __post_init__(self): + assert self.block_size > 0 + assert self.local_blocks >= 0 + assert self.vert_stride >= 1 + assert self.num_heads % self.num_kv_heads == 0 + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + total_heads = tp_size * self.num_heads + total_kv_heads = tp_size * self.num_kv_heads + + if self.homo_head: + self.head_sliding_step = 0 + elif self.homo_head_group: + head_sliding_step = get_head_sliding_step(total_kv_heads, + self.vert_stride) + # negative indicates sliding along kv heads, i.e., homo q group + self.head_sliding_step = -head_sliding_step + else: + self.head_sliding_step = get_head_sliding_step( + total_heads, self.vert_stride) + + self.active_head_range = ( + tp_rank * self.num_heads, + (tp_rank + 1) * self.num_heads, + ) + + +class BlocksparseFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: + return BlocksparseFlashAttentionImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata": + return BlocksparseFlashAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class BlocksparseFlashAttentionMetadata(AttentionMetadata): + """A copy of Metadata for FlashAttentionBackend, + to avoid having to install flash_attn. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + _cached_prefill_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + + @property + def prefill_metadata( + self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class BlocksparseFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + ) -> None: + assert blocksparse_params is not None + assert alibi_slopes is None, ValueError( + "Alibi not support for blocksparse flash attention.") + assert sliding_window is None, ValueError( + "sliding_window is invalid for blocksparse attention.") + + if "num_heads" not in blocksparse_params: + blocksparse_params["num_heads"] = num_heads + if "num_kv_heads" not in blocksparse_params: + blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads + self.blocksparse_params = BlocksparseParams(**blocksparse_params) + self.kv_cache_dtype = kv_cache_dtype + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.alibi_slopes = alibi_slopes + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.local_blocks = self.blocksparse_params.local_blocks + self.vert_stride = self.blocksparse_params.vert_stride + self.sparse_block_size = self.blocksparse_params.block_size + self.head_sliding_step = self.blocksparse_params.head_sliding_step + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + total_num_heads = num_heads * self.tp_size + self.bs_attn = LocalStridedBlockSparseAttn( + total_num_heads, + self.blocksparse_params.max_seqlen, + self.blocksparse_params.local_blocks, + self.blocksparse_params.vert_stride, + self.blocksparse_params.block_size, + homo_head=self.blocksparse_params.homo_head, + active_head_range=self.blocksparse_params.active_head_range, + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + kv_scale: float = 1.0, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + kv_scale, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + + # Prompt run. + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + + assert kv_cache is None \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0, \ + "Does not support prefix-enabled attention." + + output = self.bs_attn( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + sm_scale=self.scale, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + self.blocksparse_params.max_seqlen, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + tp_rank=self.tp_rank, + blocksparse_local_blocks=self.local_blocks, + blocksparse_vert_stride=self.vert_stride, + blocksparse_block_size=self.sparse_block_size, + blocksparse_head_sliding_step=self.head_sliding_step, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index a34e99e0a1575..0b9d6283493f2 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,24 +1,25 @@ -"""Attention layer with Flash and PagedAttention. - -NOTE(woosuk): At the moment, this file includes a lot of duplicated code from -XFormers backend. The duplicated code will be removed once we use flash-attn or -flashinfer for all the attention operations. -""" +"""Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch -from flash_attn import flash_attn_varlen_func +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) -from vllm.utils import is_hip class FlashAttentionBackend(AttentionBackend): + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "flash-attn" + @staticmethod def get_impl_cls() -> Type["FlashAttentionImpl"]: return FlashAttentionImpl @@ -34,27 +35,36 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -62,64 +72,142 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int - - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| - - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = FlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -127,28 +215,39 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") def forward( self, @@ -157,19 +256,22 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. + """Forward pass with FlashAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -177,84 +279,89 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + key_cache = kv_cache[0] + value_cache = kv_cache[1] # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + ) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - if is_hip(): - output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, - softmax_scale=self.scale, - causal=True, - ) - else: - output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) - + out = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention - # TODO(Hai) this triton kernel has regression issue (broke) to - # deal with different data types between KV and FP8 KV cache, - # to be addressed separately. - output = PagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, - self.alibi_slopes, + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + output[:num_prefill_tokens] = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, ) - else: + + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = flash_attn_with_kvcache( + decode_query.unsqueeze(1), key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, - attn_metadata.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - kv_scale, - ) + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + ).squeeze(1) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flash_attn_triton.py b/vllm/attention/backends/flash_attn_triton.py deleted file mode 100644 index 3a136d0a2476b..0000000000000 --- a/vllm/attention/backends/flash_attn_triton.py +++ /dev/null @@ -1,176 +0,0 @@ -"""Attention layer with Flash and PagedAttention. - -NOTE(woosuk): At the moment, this file includes a lot of duplicated code from -XFormers backend. The duplicated code will be removed once we use flash-attn or -flashinfer for all the attention operations. -""" -from typing import List, Optional, Type - -import torch - -from vllm.attention.backends.abstract import AttentionImpl -from vllm.attention.backends.flash_attn import (FlashAttentionBackend, - FlashAttentionMetadata) -from vllm.attention.ops.flash_attention_triton import triton_attention -from vllm.attention.ops.paged_attn import PagedAttention - - -class FlashAttentionTritonBackend(FlashAttentionBackend): - - @staticmethod - def get_impl_cls() -> Type["FlashAttentionTritonImpl"]: - return FlashAttentionTritonImpl - - -class FlashAttentionTritonImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - NOTE: This code is mostly duplicate from flash_attn backend - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") - - def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" - tokens, n_kv_heads, head_dim = x.shape - return (x[:, :, - None, :].expand(tokens, n_kv_heads, n_rep, - head_dim).reshape(tokens, n_kv_heads * n_rep, - head_dim)) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache is not None: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - ) - - if attn_metadata.is_prompt: - # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: - # triton attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - - if self.num_kv_heads != self.num_heads: - # Interleave for MQA workaround. - key = self.repeat_kv(key, self.num_queries_per_kv) - value = self.repeat_kv(value, self.num_queries_per_kv) - - output, _ = triton_attention( - query, - key, - value, - None, - attn_metadata.seq_start_loc, - attn_metadata.seq_start_loc, - attn_metadata.max_prompt_len, - attn_metadata.max_prompt_len, - True, - self.scale, - ) - else: - # prefix-enabled attention - output = PagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, - self.alibi_slopes, - ) - else: - # Decoding run. - output = PagedAttention.forward_decode( - query, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, - attn_metadata.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - ) - - # Reshape the output tensor. - return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py new file mode 100644 index 0000000000000..7b7959d257fac --- /dev/null +++ b/vllm/attention/backends/flashinfer.py @@ -0,0 +1,250 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Set, Tuple, Type + +import flashinfer +import torch +from flashinfer import BatchDecodeWithPagedKVCacheWrapper +from vllm_flash_attn import flash_attn_varlen_func + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) + + +class FlashInferBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "flashinfer" + + @staticmethod + def get_impl_cls() -> Type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "FlashInferMetadata": + return FlashInferMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise NotImplementedError + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + raise NotImplementedError + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 128, 256] + + +@dataclass +class FlashInferMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + + use_cuda_graph: bool = False + + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + + # Metadata for the prefill stage since we still + # use flash attention for prefill. + seq_start_loc: Optional[torch.Tensor] = None + block_tables: Optional[torch.Tensor] = None + + # Metadata for the decode stage + # Workspace buffer required by the kernel, the buffer should not + # be allocated/deacollated by the FalshInfermetadata object. + workspace_buffer: Optional[torch.Tensor] = None + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + # When using flashinfer, we are also creating the FlashInferMetadata, + # which will also call post_init by default, here we want to skip the + # post_init if it's the prefill phase. + if self.num_prefills == 0: + assert self.num_decode_tokens > 0 + self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, "NHD") + self.decode_wrapper.begin_forward( + self.paged_kv_indptr, + self.paged_kv_indices, + self.paged_kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + data_type=self.data_type) + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + if skip_fields is None: + skip_fields = set() + # We need to skip the decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. + skip_fields.add('decode_wrapper') + return super().asdict_zerocopy(skip_fields) + + @property + def prefill_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is not None: + raise ValueError("Sliding window is not supported in FlashInfer.") + self.sliding_window = (-1, -1) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: FlashInferMetadata, + kv_scale: float = 1.0, + ) -> torch.Tensor: + assert kv_scale == 1.0 + num_tokens, hidden_size = query.shape + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if attn_metadata.num_prefill_tokens > 0: + assert attn_metadata.num_decode_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") + if attn_metadata.num_decode_tokens > 0: + assert attn_metadata.num_prefill_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") + + if kv_cache is not None: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + assert prefill_meta.block_tables is not None + if kv_cache is None or prefill_meta.block_tables.numel() == 0: + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + else: + raise NotImplementedError( + "Prefix caching is not supported with flashinfer yet.") + else: + assert attn_metadata.decode_metadata is not None + assert attn_metadata.decode_metadata.decode_wrapper is not None + query = query.contiguous( + ) # Flashinfer requires query to be contiguous + output = attn_metadata.decode_metadata.decode_wrapper.forward( + query, + kv_cache, + sm_scale=self.scale, + ) + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8acd401833cc9..e92e6c5e2dc8d 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,10 +1,10 @@ """Attention layer ROCm GPUs.""" -import os from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch +import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, @@ -16,6 +16,10 @@ class ROCmFlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "rocm-flash-attn" + @staticmethod def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: return ROCmFlashAttentionImpl @@ -38,14 +42,14 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) @@ -59,38 +63,32 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int - - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| - - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. @@ -100,62 +98,72 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - - -def _make_alibi_bias( - alibi_slopes: torch.Tensor, - dtype: torch.dtype, - seq_lens: List[int], -) -> List[torch.Tensor]: - attn_biases = [] - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - num_heads = alibi_slopes.shape[0] - bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device) - bias.mul_(alibi_slopes[:, None, None]) - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device) - attn_biases.append((bias + inf_mask).to(dtype)) - - return attn_biases - - -def _make_alibi_bias_v2( - alibi_slopes: torch.Tensor, - dtype: torch.dtype, - seq_lens: List[int], - make_attn_mask: bool = True -) -> List[torch.Tensor]: - attn_biases = [] - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - num_heads = alibi_slopes.shape[0] - bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device) - bias.mul_(alibi_slopes[:, None, None]) - if make_attn_mask: - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device) - attn_biases.append((bias + inf_mask).to(dtype)) - else: - attn_biases.append(bias.to(dtype)) - - return attn_biases + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = ROCmFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata class ROCmFlashAttentionImpl(AttentionImpl): @@ -173,6 +181,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -180,48 +197,58 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "ROCFlashAttention does not support blocksparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") - self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9 + self.use_naive_attn = False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. - self.use_triton_flash_attn = (os.environ.get( - "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) - if self.use_naive_attn: - # AMD Radeon 7900 series (gfx1100) currently does not support - # xFormers nor FlashAttention. As a temporary workaround, we use - # naive PyTorch implementation of attention. - self.attn_func = _naive_attention() - logger.debug("Using naive attention in ROCmBackend") - elif self.use_triton_flash_attn: + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN + if self.use_triton_flash_attn: from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) self.attn_func = triton_attention logger.debug("Using Triton FA in ROCmBackend") else: - from flash_attn import flash_attn_varlen_func # noqa: F401 - self.attn_func = flash_attn_varlen_func - logger.debug("Using CK FA in ROCmBackend") + # if not using triton, navi3x/navi21/navi10 do not use flash-attn + # either + if torch.cuda.get_device_capability()[0] != 9: + self.use_naive_attn = True + else: + try: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + except ModuleNotFoundError: + self.use_naive_attn = True + + if self.use_naive_attn: + self.attn_func = _naive_attention + logger.debug("Using naive attention in ROCmBackend") def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" @@ -270,87 +297,101 @@ def forward( key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, kv_scale, ) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + assert prefill_meta.seq_lens is not None + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - if self.use_naive_attn or self.use_triton_flash_attn: + if self.use_triton_flash_attn: + out, _ = self.attn_func( + query, + key, + value, + None, + prefill_meta.seq_start_loc, + prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + prefill_meta.max_prefill_seq_len, + True, + self.scale, + ) + elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) - att_masks = None - if self.alibi_slopes is not None: - att_masks = _make_alibi_bias_v2( - self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens, make_attn_mask=False) # type: ignore - if self.use_naive_attn: - output = self.attn_func( - query, - key, - value, - attn_metadata.prompt_lens, - self.scale, - att_masks - ) - else: - output, _ = self.attn_func( - query, - key, - value, - None, - attn_metadata.seq_start_loc, - attn_metadata.seq_start_loc, - attn_metadata.max_prompt_len, - attn_metadata.max_prompt_len, - True, - self.scale, - att_masks[0][None] if att_masks is not None else None, - ) + out = self.attn_func( + query, + key, + value, + prefill_meta.seq_lens, + self.scale, + ) else: - output = self.attn_func( + out = self.attn_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, ) + # common code for prefill + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, + self.sliding_window[0], ) - else: + + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, - attn_metadata.kv_cache_dtype, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_decode_seq_len, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, @@ -365,31 +406,24 @@ def _naive_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - prompt_lens: List[int], + seq_lens: List[int], scale: float, - attn_masks: Optional[List[torch.Tensor]], ) -> torch.Tensor: - num_tokens = query.shape[0] output = torch.empty_like(query) start = 0 - for i, prompt_len in enumerate(prompt_lens): - end = start + prompt_len + for _, seq_len in enumerate(seq_lens): + end = start + seq_len out = _naive_masked_attention( - query[None, start:end], - key[None, start:end], - value[None, start:end], + query[start:end], + key[start:end], + value[start:end], scale, - attn_masks[i], ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) - start += prompt_len + start += seq_len - # Using view got RuntimeError: view size is not compatible - # with input tensor's size and stride (at least one - # dimension spans across two contiguous subspaces). - # Use reshape instead. - return output.reshape(num_tokens, -1) + return output def _naive_masked_attention( @@ -397,17 +431,14 @@ def _naive_masked_attention( key: torch.Tensor, value: torch.Tensor, scale: float, - attn_mask: Optional[torch.Tensor], ) -> torch.Tensor: - seq_len, _, _ = query.shape - if attn_mask is None: - attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) - attn_mask = attn_mask * torch.finfo(query.dtype).min - + seq_len, head_size, head_dim = query.shape + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() attn_weights = attn_weights + attn_mask.float() attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 9706e1910cb79..9b50adec5244d 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -1,7 +1,7 @@ """ Attention layer with torch scaled_dot_product_attention and PagedAttention.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch from torch.nn.functional import scaled_dot_product_attention @@ -14,6 +14,10 @@ class TorchSDPABackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "torch-sdpa" + @staticmethod def get_impl_cls() -> Type["TorchSDPABackendImpl"]: return TorchSDPABackendImpl @@ -36,14 +40,14 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) @@ -56,16 +60,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): # or all decoding. True if all sequences are prompts. is_prompt: bool slot_mapping: torch.Tensor - prompt_lens: Optional[List[int]] - prompt_lens_tensor: Optional[torch.Tensor] - num_prompt_tokens: int - num_generation_tokens: int - - max_subquery_len: Optional[int] = None - max_prompt_len: Optional[int] = None - subquery_start_loc: Optional[torch.Tensor] = None - seq_start_loc: Optional[torch.Tensor] = None - use_cuda_graph: bool = False + seq_lens: Optional[List[int]] def __post_init__(self): # Set during the execution of the first attention op. @@ -75,37 +70,64 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[torch.Tensor]] = None + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None -class TorchSDPABackendImpl(AttentionImpl): + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): def __init__( self, num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "Torch SPDA does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: - assert len(alibi_slopes) == num_heads alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - self.need_mask = (self.alibi_slopes is not None - or self.sliding_window is not None) + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch SDPA backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") def forward( self, @@ -113,8 +135,8 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, - kv_scale: float, + attn_metadata: TorchSDPAMetadata, # type: ignore + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -127,6 +149,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -139,10 +162,10 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) if attn_metadata.is_prompt: + assert attn_metadata.seq_lens is not None if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) @@ -153,13 +176,13 @@ def forward( if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens) # type: ignore + attn_metadata.seq_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.prompt_lens, self.sliding_window, + attn_metadata.seq_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.prompt_lens) + att_masks = [None] * len(attn_metadata.seq_lens) attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) @@ -170,9 +193,9 @@ def forward( output = torch.empty( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype) - for prompt_len, mask in zip(attn_metadata.prompt_lens, - attn_metadata.attn_bias): - end = start + prompt_len + for seq_len, mask in zip(attn_metadata.seq_lens, + attn_metadata.attn_bias): + end = start + seq_len sub_out = scaled_dot_product_attention( query[:, start:end, :], key[:, start:end, :], @@ -195,9 +218,9 @@ def forward( key_cache, value_cache, attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, - attn_metadata.kv_cache_dtype, + attn_metadata.seq_lens_tensor, + attn_metadata.max_decode_seq_len, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, @@ -211,23 +234,23 @@ def forward( def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, - prompt_lens: List[int], + seq_lens: List[int], ) -> List[torch.Tensor]: attn_biases = [] - for prompt_len in prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` + # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. bias = bias[None, :] - bias[:, None] num_heads = alibi_slopes.shape[0] - bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) + bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( - (1, prompt_len, prompt_len), + (1, seq_len, seq_len), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) attn_biases.append((bias + inf_mask).to(dtype)) @@ -235,14 +258,14 @@ def _make_alibi_bias( def _make_sliding_window_bias( - prompt_lens: List[int], + seq_lens: List[int], window_size: Optional[int], dtype: torch.dtype, ) -> List[torch.Tensor]: attn_biases = [] - for prompt_len in prompt_lens: + for seq_len in seq_lens: tensor = torch.full( - (1, prompt_len, prompt_len), + (1, seq_len, seq_len), dtype=dtype, fill_value=1, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 05b68bba5e6eb..99a3e88bc07b6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -1,6 +1,6 @@ """Attention layer with xFormers and PagedAttention.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch from xformers import ops as xops @@ -19,6 +19,10 @@ class XFormersBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "xformers" + @staticmethod def get_impl_cls() -> Type["XFormersImpl"]: return XFormersImpl @@ -48,7 +52,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) @@ -62,54 +66,47 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int - - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] # FIXME: It is for flash attn. - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None def __post_init__(self): # Set during the execution of the first attention op. @@ -119,22 +116,91 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None - -class XFormersImpl(AttentionImpl): + @property + def prefill_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class XFormersImpl(AttentionImpl[XFormersMetadata]): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens --------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -142,18 +208,23 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "XFormer does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -170,8 +241,8 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: XFormersMetadata, - kv_scale: float, + attn_metadata: "XFormersMetadata", + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -184,7 +255,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) @@ -199,63 +269,65 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens - if attn_metadata.is_prompt: + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention. # block tables are empty if the prompt does not have a cached # prefix. - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - output = self._run_memory_efficient_xformers_forward( - query, key, value, attn_metadata) + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta) + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + out = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, + self.sliding_window, ) - else: - # Decoding run. - output = PagedAttention.forward_decode( - query, + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + + if decode_meta := attn_metadata.decode_metadata: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, - attn_metadata.kv_cache_dtype, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_decode_seq_len, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, @@ -275,20 +347,38 @@ def _run_memory_efficient_xformers_forward( """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. + Args: - output: shape = [num_prompt_tokens, num_heads, head_size] - query: shape = [num_prompt_tokens, num_heads, head_size] - key: shape = [num_prompt_tokens, num_kv_heads, head_size] - value: shape = [num_prompt_tokens, num_kv_heads, head_size] + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ + assert attn_metadata.seq_lens is not None + original_query = query + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.prompt_lens) + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) @@ -296,12 +386,13 @@ def _run_memory_efficient_xformers_forward( else: attn_metadata.attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, query.dtype, - attn_metadata.prompt_lens) + attn_metadata.seq_lens) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. if self.alibi_slopes is None: + # Add the batch dimension. query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) @@ -312,17 +403,16 @@ def _run_memory_efficient_xformers_forward( attn_bias=attn_metadata.attn_bias[0], p=0.0, scale=self.scale) - - return out.view_as(query) + return out.view_as(original_query) # Attention with alibi slopes. # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. - output = torch.empty_like(query) + output = torch.empty_like(original_query) start = 0 - for i, prompt_len in enumerate(attn_metadata.prompt_lens): - end = start + prompt_len + for i, seq_len in enumerate(attn_metadata.seq_lens): + end = start + seq_len out = xops.memory_efficient_attention_forward( query[None, start:end], key[None, start:end], @@ -331,8 +421,8 @@ def _run_memory_efficient_xformers_forward( p=0.0, scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out.squeeze(0)) - start += prompt_len + output[start:end].copy_(out.view_as(original_query[start:end])) + start += seq_len return output @@ -340,13 +430,13 @@ def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, dtype: torch.dtype, - prompt_lens: List[int], + seq_lens: List[int], ) -> LowerTriangularMaskWithTensorBias: attn_biases = [] - for prompt_len in prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` + # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. @@ -354,16 +444,16 @@ def _make_alibi_bias( # element. bias = bias[None, :] - bias[:, None] - padded_len = (prompt_len + 7) // 8 * 8 + padded_len = (seq_len + 7) // 8 * 8 num_heads = alibi_slopes.shape[0] bias = torch.empty( 1, # batch size num_heads, - prompt_len, + seq_len, padded_len, device=alibi_slopes.device, dtype=dtype, - )[:, :, :, :prompt_len].copy_(bias) + )[:, :, :, :seq_len].copy_(bias) bias.mul_(alibi_slopes[:, None, None]) if num_heads != num_kv_heads: bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9856654fc5f94..db55a31476fed 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,11 +1,14 @@ """Attention layer.""" -from typing import List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn as nn from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) class Attention(nn.Module): @@ -27,13 +30,53 @@ def __init__( scale: float, num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() - self.backend = get_attn_backend(torch.get_default_dtype()) - impl_cls = self.backend.get_impl_cls() + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + sliding_window = cache_config.sliding_window + else: + kv_cache_dtype = "auto" + block_size = 16 + sliding_window = None + if num_kv_heads is None: + num_kv_heads = num_heads + + # The default kv_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized kv_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self._kv_scale = 1.0 + quant_method = quant_config.get_quant_method( + self) if quant_config else None + if quant_method is not None: + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # When FP8 quantization is enabled, we make a parameter + # "kv_scale" so that it can be loaded from FP8 checkpoint. + # The kv_scale will then be converted back + # to self._kv_scale in a native float32 value after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size, blocksparse_params + is not None) + impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window) + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params) def forward( self, @@ -42,7 +85,14 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, - kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, - kv_scale) + self._kv_scale) + + def extra_repr(self) -> str: + s = f"head_size={self.impl.head_size}" # type: ignore + s += f", num_heads={self.impl.num_heads}" # type: ignore + s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore + s += f", scale={self.impl.scale}" # type: ignore + s += f", backend={self.impl.__class__.__name__}" + return s diff --git a/vllm/attention/ops/blocksparse_attention/__init__.py b/vllm/attention/ops/blocksparse_attention/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py new file mode 100644 index 0000000000000..ec1c37c5bcb0e --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -0,0 +1,423 @@ +import torch +import triton +import triton.language as tl + + +def blocksparse_flash_attn_varlen_fwd( + q, + k, + v, # (#tokens, n_heads, head_size) + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + sparse_layout, + *, + block_size=64, + q_block_size=None, + max_seqlen=None): + # split q to blocks + + assert isinstance(sparse_layout, (list, tuple)) + + _, n_heads, head_size = q.shape + batch_size = cu_seqlens_k.size(0) - 1 + q_block_size = q_block_size or block_size + + assert q.dim() == k.dim() == v.dim() == 3 + assert q.size(1) % k.size(1) == 0 + assert q.size(2) == k.size(2) + # TODO(linxihui): allow k, v to have different head_size + assert k.shape == v.shape + assert cu_seqlens_k.dim() == 1 + + q_k_ratio = q.size(1) // k.size(1) + + if cu_seqlens_q is None: + if q.size(0) == batch_size: # decoding only + cu_seqlens_q = torch.arange( + 0, + batch_size + 1, + dtype=cu_seqlens_k.dtype, + device=cu_seqlens_k.device, + ) + elif q.size(0) == k.size(0): + cu_seqlens_q = cu_seqlens_k + else: + raise ValueError("cu_seqlens_q must be specified\ + if it mix of prefilling and decoding.") + else: + assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) + + # switch to use cpu to avoid too many kernel launches when iterated over + q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() + k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() + + assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), ( + "length of q should either be 1 (decoding) or same as k (prefilling).") + + if max_seqlen: + assert k_lens.max() <= max_seqlen + + n_blocks = (q_lens + q_block_size - 1) // q_block_size + + q_batch_ids = torch.tensor( + [i for i, n in enumerate(n_blocks) for _ in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + q_start_sids = torch.tensor( + [i * q_block_size for n in n_blocks for i in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + + out = q.new_empty(q.shape) + cu_seqlens_q = cu_seqlens_q.contiguous() + cu_seqlens_k = cu_seqlens_k.contiguous() + + layout_crow_indices, layout_col_indices = sparse_layout + block_d = triton.next_power_of_2(head_size) + + decoding_only = (q_lens == 1).all().item() + grid = (len(q_start_sids), n_heads, 1) + + _fwd_kernel_batch_inference[grid]( + q, + k, + v, + out, + sm_scale, + cu_seqlens_q[:-1], + cu_seqlens_q[1:], + cu_seqlens_k[:-1], + cu_seqlens_k[1:], + q_batch_ids, + q_start_sids, + 0, + *q.stride(), + 0, + *k.stride(), + 0, + *v.stride(), + 0, + *out.stride(), + layout_crow_indices, + layout_col_indices, + *layout_crow_indices.stride(), + *layout_col_indices.stride(), + q_k_ratio, + HAS_BATCH_DIM=False, + D_HEAD=head_size, + BLOCK_M=q_block_size, + BLOCK_N=block_size, + BLOCK_D=block_d, + BLOCK_M_LOADING=(16 if decoding_only else + q_block_size), # smaller for decoding + EVEN_D=block_d == head_size, + num_warps=1 if decoding_only else 4, + num_stages=3) + + return out + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + BLOCK_N: tl.constexpr, + D_HEAD: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + + k_block_col_idx * layout_col_stride_m).to(tl.int32) + start_n = k_block_id * BLOCK_N + if LAST_K_BLOCK: + if EVEN_D: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=offs_n[None, :] + start_n < k_seqlen, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=(offs_n[None, :] + start_n < k_seqlen) & + (offs_d[:, None] < D_HEAD), + ) + else: + if EVEN_D: + k = tl.load(k_ptrs + start_n * stride_kt) + else: + k = tl.load(k_ptrs + start_n * stride_kt, + mask=offs_d[:, None] < D_HEAD) + + qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK | M_LT_N: + qk += tl.where( + offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), + 0, + float("-inf"), + ) + + # flash-attn2 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + # update m_i + m_i = m_ij + l_i = l_i * alpha + l_ij + + p = p.to(Q.dtype.element_ty) + # update acc + if LAST_K_BLOCK: + if EVEN_D: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=offs_n[:, None] + start_n < k_seqlen, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=(offs_n[:, None] + start_n < k_seqlen) & + (offs_d[None, :] < D_HEAD), + ) + else: + if EVEN_D: + v = tl.load(v_ptrs + start_n * stride_vt) + else: + v = tl.load(v_ptrs + start_n * stride_vt, + mask=offs_d[None, :] < D_HEAD) + + acc += tl.dot(p, v) + + return acc, l_i, m_i + + +@triton.heuristics({ + "M_LT_N": + lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"], +}) +@triton.jit +def _fwd_kernel_batch_inference( + Q, + K, + V, + Out, + sm_scale, + q_batch_starts, + q_batch_ends, + k_batch_starts, + k_batch_ends, + q_batch_ids, + q_start_sids, + stride_qb, + stride_qt, + stride_qh, + stride_qd, + stride_kb, + stride_kt, + stride_kh, + stride_kd, + stride_vb, + stride_vt, + stride_vh, + stride_vd, + stride_ob, + stride_ot, + stride_oh, + stride_od, + layout_crow_ptr, + layout_col_ptr, + layout_crow_stride_h, + layout_crow_stride_m, + layout_col_stride_h, + layout_col_stride_m, + q_k_ratio, + HAS_BATCH_DIM: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + """ + NOTATION: + pid: position id + sid: storage id + sbid: storage block id + pbid: position block id + offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) + + TODO(linxihui): + Optimize grouped-attn + """ + off_zm = tl.program_id(0) + off_h = tl.program_id(1) + + off_h_for_kv = off_h // q_k_ratio + + if HAS_BATCH_DIM: + off_z = tl.program_id(2) + Q += off_z * stride_qb + K += off_z * stride_kb + V += off_z * stride_vb + Out += off_z * stride_ob + start_m = off_zm + q_start_sid = start_m * BLOCK_M # always 0 for decoding + else: + off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] + q_start_sid = tl.load(q_start_sids + off_zm) + start_m = q_start_sid // BLOCK_M # q_sbid + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) + q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start + k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) + k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start + past_len = k_seqlen - q_seqlen + + Q += q_cu_start * stride_qt + off_h * stride_qh + K += k_cu_start * stride_kt + off_h_for_kv * stride_kh + V += k_cu_start * stride_vt + off_h_for_kv * stride_vh + Out += q_cu_start * stride_ot + off_h * stride_oh + + q_pbid = (past_len + q_start_sid) // BLOCK_M + + if EVEN_D: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=offs_m[:, None] < q_seqlen, + ) + else: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + other=0, + ) + + sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + + q_pbid * layout_crow_stride_m) + + # TODO(linxihui): load at once, with any Triton version + # that supports `tl.split`, e.g., Triton 3.0 + k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) + k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) + + m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) + acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) + + k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + sm_scale *= ( + 1.44269504 # 1/log2 as we use base2 for exponential and logarithm + ) + + for k_block_col_idx in range(k_block_start, k_block_end - 1): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + False, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_end - 1, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + True, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + # flash-attn 2 + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + + # write output + if EVEN_D: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=offs_m[:, None] < q_seqlen, + ) + else: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + ) diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py new file mode 100644 index 0000000000000..300211e70bb79 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -0,0 +1,238 @@ +import math + +import torch + +from vllm.utils import is_cpu, is_hip + +from .utils import (dense_to_crow_col, get_head_sliding_step, + get_sparse_attn_mask) + +IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 8) + +if IS_COMPUTE_8_OR_ABOVE: + from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd + + +class LocalStridedBlockSparseAttn(torch.nn.Module): + + def __init__( + self, + n_heads, + max_seqlen, + local_blocks, + vert_stride, + block_size, + device=None, + dtype=None, + homo_head=False, + active_head_range=None, + q_block_size=None, + use_spda=None, + ): + super().__init__() + if use_spda is None: + use_spda = is_hip() or is_cpu() or not \ + IS_COMPUTE_8_OR_ABOVE + device = device or (torch.cuda.current_device() + if torch.cuda.is_available() else "cpu") + device = torch.device(device) + # NOTE: vllm CPU backend support BF16 instead of FP16. + dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE + or device.type == "cpu" else torch.half) + + self.n_heads = n_heads + self.max_seqlen = max_seqlen + self.local_blocks = local_blocks + self.vert_stride = vert_stride + self.use_spda = use_spda + self.dtype = dtype + self.device = device + self.block_size = block_size + self.q_block_size = q_block_size + self.homo_head = homo_head + self.active_head_range = active_head_range + self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride, + homo_head) + + sparse_layout, sparse_pattern, self.dense_attn_mask = ( + self.get_attn_pattern(dtype, device)) + + if q_block_size is not None and q_block_size != block_size: + if q_block_size > block_size: + assert q_block_size % block_size == 0 + blocks_to_merge = q_block_size // block_size + shape = sparse_pattern.shape + sparse_pattern = sparse_pattern.view(shape[0], -1, + blocks_to_merge, + shape[-1]) + sparse_pattern = sparse_pattern.sum(2) + sparse_layout = dense_to_crow_col(sparse_pattern) + else: + raise ValueError( + "Does not support smaller q_block_size. It will be slower." + ) + + self.sparse_layout = sparse_layout + + def get_attn_pattern(self, dtype, device): + sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask( + self.n_heads, + self.max_seqlen, + self.max_seqlen, + dtype, + device, + block_size=self.block_size, + local_blocks=self.local_blocks, + vert_stride=self.vert_stride, + homo_head=self.homo_head, + return_dense=self.use_spda, + dense_mask_type="bias", + ) + if (not self.homo_head) and (self.active_head_range is not None): + assert isinstance(self.active_head_range, tuple) + assert (len(self.active_head_range) == 2) + h_start, h_end = self.active_head_range + sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout) + if self.use_spda: + dense_attn_mask = dense_attn_mask[h_start:h_end] + return sparse_layout, sparse_pattern, dense_attn_mask + + def varlen_attn(self, + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=None, + sm_scale=None): + """ + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), + indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify is when q is a mix of + prefilling and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert ( + IS_COMPUTE_8_OR_ABOVE + ), "Requires compute capability of 8 or above (Ampere or newer) to use \ + Triton kernel." + + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + + return blocksparse_flash_attn_varlen_fwd( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + self.sparse_layout, + block_size=self.block_size, + q_block_size=self.q_block_size, + max_seqlen=self.max_seqlen, + ) + + @staticmethod + def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1): + """ + :param x: (total_tokens, n_heads, head_size) + :return: (batch, n_heads, length, head_size) + """ + x_padded = x.new_empty( + len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2)) + cu_seqlens = cu_seqlens.cpu() + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0, + 1).unsqueeze(1)) + return x_padded.flatten(1, 2) + + @staticmethod + def transpose_and_unpad(x_padded, cu_seqlens): + """ + :param x_padded: (batch, n_heads, length, head_size) + :return: (total_tokens, n_heads, head_size) + """ + cu_seqlens = cu_seqlens.cpu() + total_n_tokens = cu_seqlens[-1] + x = x_padded.new_empty(total_n_tokens, x_padded.size(1), + x_padded.size(3)) + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1)) + return x + + def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """For CPU, V100 or other older GPUs. + NOTE: torch SPDA supports nested tensor, + but seems extremely slow. Choose to pad instead. + """ + assert (cu_seqlens_q is None or + (cu_seqlens_q + == cu_seqlens_k).all()), "Can only handle prompt with SPDA." + assert q.size(0) == k.size(0), "can only handle prompt with SPDA." + + assert q.size(1) % k.size(1) == 0 + q_k_ratio = q.size(1) // k.size(1) + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + cu_seqlens = cu_seqlens_k.cpu() + maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + if (self.dense_attn_mask.dtype != q.dtype + or self.dense_attn_mask.device != q.device): + _, _, self.dense_attn_mask = self.get_attn_pattern( + q.dtype, q.device) + attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen] + + q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1) + k2, v2 = [ + self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio) + for x in [k, v] + ] + spda_output = torch.nn.functional.scaled_dot_product_attention( + q2, k2, v2, attn_mask=attn_mask, scale=sm_scale) + return self.transpose_and_unpad(spda_output, cu_seqlens) + + def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """Dispatch to `varlen_attn` (Ampere or newer) or + `self.spda`(cpu, Volta, Turing or older)based on + the type of device used and cuda compute capability. + + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify + is when q is a mix of prefilling + and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert k.dim() == 3 + if self.use_spda: + return self.spda( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale, + ) + return self.varlen_attn(q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale) \ No newline at end of file diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py new file mode 100644 index 0000000000000..0d90dd971e156 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/utils.py @@ -0,0 +1,216 @@ +# Helper functions for 3D sparse pattern +# These function are not optimized and very inefficient. +# Avoid calling them too frequent or use a cache mechanism. + +from functools import lru_cache + +import torch +import triton +from scipy import sparse + + +def dense_to_crow_col(x: torch.Tensor): + """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing. + NOTE: col_indices padded -1 + """ + device = x.device + pad = -1 + dim = x.dim() + assert x.dim() in (2, 3) + if x.dim() == 2: + x = x[None] + x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x] + crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) + cols = [torch.from_numpy(xi.indices) for xi in x] + max_cols = max(len(xi) for xi in cols) + cols = [ + torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) + for xi in cols + ] + cols = torch.vstack(cols) + if dim == 2: + crows = crows[0] + cols = cols[0] + return crows.to(device), cols.to(device) + + +def crow_col_to_dense(crows: torch.Tensor, + cols: torch.Tensor, + dtype: torch.dtype = torch.float16): + dim = crows.dim() + if dim == 1: + crows = crows[None] + cols = cols[None] + device = crows.device + crows, cols = crows.cpu(), cols.cpu() # faster in cpu + shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) + x = torch.zeros(shape, dtype=dtype) + for i in range(shape[0]): + for j in range(shape[1]): + x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1 + if dim == 1: + x = x[0] + return x.to(device) + + +def dense_to_ccol_row(x: torch.Tensor): + """Similar, but to CSC format""" + x = x.transpose(-2, -1) + return dense_to_crow_col(x) + + +def ccol_row_to_dense(ccol: torch.Tensor, + rows: torch.Tensor, + dtype: torch.dtype = torch.float16): + return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() + + +def _get_sparse_attn_mask_homo_head( + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 128, + local_blocks: int = 4, + vert_stride: int = 4, + return_dense: bool = False, +): + """ + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be + OOM if it is too big) if `return_dense==True`, + otherwise, None + """ + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[:, None] + k_pos = torch.arange(num_blocks)[None] + mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0 + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = (dense_to_crow_col( + block_mask_dense[-num_blocks_q:].contiguous())) + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask + return ( + block_mask_dense_output, + block_mask_dense, + mask_dense, + ) + else: + return ( + block_mask_dense_output, + block_mask_dense, + None, + ) + + +def binary_mask_to_bias(mask_dense: torch.Tensor): + mask_dense = 1 - mask_dense + mask_dense.masked_fill_(mask_dense.bool(), -torch.inf) + return mask_dense + + +def get_head_sliding_step(n_heads: int, + vert_stride: int, + homo_head: bool = False): + if homo_head: + return 0 + return max(1, int(vert_stride / n_heads)) + + +@lru_cache +def get_sparse_attn_mask( + n_heads: int, + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 64, + local_blocks: int = 4, + vert_stride: int = 4, + homo_head: bool = True, + return_dense: bool = False, + dense_mask_type: str = "binary", +): + """ + :param dense_mask_type: "binary" (0 for skip token, 1 for others) + or "bias" (-inf for skip token, 0 or others) + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be OOM if it + is too big) if `return_dense==True`, otherwise, None + """ + assert dense_mask_type in ("binary", "bias") + if homo_head: + with torch.no_grad(): + (crow, col), block_mask_dense, mask_dense = ( + _get_sparse_attn_mask_homo_head( + q_len, + max_seqlen, + dtype, + device, + block_size, + local_blocks, + vert_stride, + return_dense, + )) + crow = crow[None].expand(n_heads, crow.shape[0]) + col = col[None].expand(n_heads, col.shape[0]) + if return_dense: + mask_dense = mask_dense[None].expand(n_heads, + *mask_dense.shape) + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + return (crow, col), block_mask_dense, mask_dense + + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[None, :, None] + k_pos = torch.arange(num_blocks)[None, None] + head_sliding_step = get_head_sliding_step(n_heads, vert_stride) + mask_vert_strided = [ + (torch.arange(num_blocks) + h * head_sliding_step + 1) % + vert_stride == 0 for h in range(n_heads) + ] + mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = block_mask_dense[:, -num_blocks_q:] + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None] + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + mask_dense, + ) + else: + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + None, + ) diff --git a/vllm/attention/ops/flash_attention_triton.py b/vllm/attention/ops/flash_attention_triton.py deleted file mode 100644 index b86e845020b07..0000000000000 --- a/vllm/attention/ops/flash_attention_triton.py +++ /dev/null @@ -1,809 +0,0 @@ -#!/usr/bin/env python -""" -Fused Attention -=============== - -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao -(https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team - -Features supported: - -1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. - -Not currently supported: - -1) Non power of two head dims - -""" - -import torch -import triton -import triton.language as tl - -torch_dtype: tl.constexpr = torch.float16 - - -@triton.jit -def cdiv_fn(x, y): - return (x + y - 1) // y - - -@triton.jit -def max_fn(x, y): - return tl.math.max(x, y) - - -@triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): - ms = tl.arange(0, m) - ns = tl.arange(0, n) - return philox_offset + ms[:, None] * stride + ns[None, :] - - -@triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, - stride).to(tl.uint32) - # TODO: use tl.randint for better performance - return tl.rand(philox_seed, rng_offsets) - - -@triton.jit -def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, - stride) - rng_keep = rng_output > dropout_p - return rng_keep - - -@triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) - else: - tensor = tl.load(block_ptr) - return tensor - - -@triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - actual_seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, -): - # loop over k, v, and update accumulator - for start_n in range(block_min, block_max, BLOCK_N): - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) - if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - # We start from end of seqlen_k so only the first iteration would need - # to be checked for padding if it is not a multiple of block_n - # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: # noqa: SIM102 - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. - # check if this masking works for that case. - if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], - actual_seqlen_k, - dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] - mask = size_n < boundary_m[:, None] - qk = tl.where(mask, qk, float("-inf")) - if IS_CAUSAL: - causal_boundary = start_n + offs_n_causal - causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - qk = tl.where(causal_mask, qk, float("-inf")) - # -- compute qk ---- - qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS - and (n_extra_tokens != 0), "zero") - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. - qk += bias * 1.44269504089 - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - - # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - philox_offset = (batch_philox_offset + - start_m * BLOCK_M * actual_seqlen_k + start_n - - BLOCK_N) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) - if RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, - -p).to(encoded_softmax_block_ptr.type.element_ty), - ) - p = tl.where(keep, p, 0.0) - elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) - # -- update m_i and l_i - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, BLOCK_N)) - return acc, l_i, m_i - - -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 64, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": True, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 64, - "BLOCK_N": 64, - "waves_per_eu": 4, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 4, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - # TODO: This config fails with head_size not pow2 with data mismatches. - # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, - # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config( - { - "BLOCK_M": 16, - "BLOCK_N": 16, - "waves_per_eu": 1, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - ], - key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], -) -@triton.jit -def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - hq, - hk, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, -): - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - if VARLEN: - cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) - cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. - if start_m * BLOCK_M > seqlen_q: - return - cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) - cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start - else: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = MAX_SEQLENS_Q - seqlen_k = MAX_SEQLENS_K - - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if IS_CAUSAL: - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn - # matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is - # part of the blocks that are all 0. We exit early. - if n_blocks <= 0: - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. - # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? - return - - is_mqa = hq != hk - off_h_k = off_h_q % hk if is_mqa else off_h_q - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL - - # Compute pointers for all the tensors used in this kernel. - q_offset = (off_z * stride_qz + off_h_q * stride_qh + - cu_seqlens_q_start * stride_qm) - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = (off_z * stride_kz + off_h_k * stride_kh + - cu_seqlens_k_start * stride_kn) - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = (off_z * stride_vz + off_h_k * stride_vh + - cu_seqlens_k_start * stride_vk) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - bias_ptr = None - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base \ - + (off_z * hq + off_h_q) \ - * seqlen_q * seqlen_k - else: - batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - encoded_softmax_block_ptr = 0 - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use 2^x in the loop as we do not - # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 - # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, padded_head, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) - - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional - # block. In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its - # value because there is no masking. Similarly we do not need padding. - if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - bias_ptr, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - False, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, n_full_blocks)) - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - True, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) - # epilogue - acc = acc / l_i[:, None] - if ENABLE_DROPOUT: - acc = acc / (1 - dropout_p) - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), - causal_start_idx, - dtype=tl.int32) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] >= - out_mask_boundary[None, :]) - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - # write back LSE - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. - # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - - # write back O - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) - - -def check_args( - q, - k, - v, - o, - varlen=True, - max_seqlens=None, - cu_seqlens_q=None, - cu_seqlens_k=None, -): - assert q.dim() == k.dim() and q.dim() == v.dim() - if varlen: - assert q.dim() == 3 - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - assert cu_seqlens_q is not None - assert cu_seqlens_k is not None - assert len(cu_seqlens_q) == len(cu_seqlens_k) - else: - assert q.dim() == 4 - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - assert max_seqlens > 0 - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - assert q.dtype == k.dtype and q.dtype == v.dtype - # TODO: Fix assert to check head size <=256 once supported - assert head_size <= 128 - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 - - -class _attention(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - max_seqlens_q, - max_seqlens_k, - causal=False, - sm_scale=1.0, - bias=None, - ): - if o is None: - o = torch.empty_like(q, dtype=v.dtype) - - check_args( - q, - k, - v, - o, - varlen=True, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - ) - if True: # varlen - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = len(cu_seqlens_q) - 1 - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, seqlen_q, nheads_q, head_size = q.shape - _, seqlen_k, nheads_k, _ = k.shape - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - - # Get closest power of 2 over or equal to 32. - unpadded_head_dims = {32, 64, 128} - if head_size not in unpadded_head_dims: - padded_d_model = None - for i in unpadded_head_dims: - if i > head_size: - padded_d_model = i - break - assert padded_d_model is not None - else: - padded_d_model = head_size - - grid = lambda META: ( - triton.cdiv(max_seqlens_q, META["BLOCK_M"]), - nheads_q, - batch, - ) - - encoded_softmax = None - - # Seed the RNG so we get reproducible results for testing. - philox_seed = 0x1BF52 - philox_offset = 0x1D4B42 - - if bias is not None: - bias_strides = ( - bias.stride(0), - bias.stride(1), - bias.stride(2), - bias.stride(3), - ) - else: - bias_strides = (0, 0, 0, 0) - - attn_fwd[grid]( - q, - k, - v, - bias, - sm_scale, - None, - o, - *q_strides, - *k_strides, - *v_strides, - *o_strides, - *bias_strides, - cu_seqlens_q, - cu_seqlens_k, - dropout_p=0.0, - philox_seed=philox_seed, - philox_offset_base=philox_offset, - encoded_softmax=encoded_softmax, - hq=nheads_q, - hk=nheads_k, - ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, - IS_CAUSAL=causal, - VARLEN=True, - BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, - ENABLE_DROPOUT=False, - RETURN_ENCODED_SOFTMAX=False, - ) - - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = head_size - ctx.causal = causal - ctx.dropout_p = 0.0 - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = False - return o, encoded_softmax - - -triton_attention = _attention.apply diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 0364213753431..0d3ee47193306 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -1,10 +1,10 @@ import os from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch -from vllm._C import cache_ops, ops +from vllm import _custom_ops as ops from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.utils import is_hip @@ -21,17 +21,11 @@ @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - # (batch_size,). The length of context (tokens stored in KV cache) per - # sequence. WARNING: When it is a prefill request, it doesn't include new - # tokens. When it is for decoding, it includes a new token. - context_lens: Optional[torch.Tensor] - # Maximum context length in the batch. - max_context_len: Optional[int] + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -39,14 +33,13 @@ class PagedAttentionMetadata: # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. block_tables: Optional[torch.Tensor] - kv_cache_dtype: str class PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 128, 256] + return [64, 80, 96, 112, 128, 192, 256] @staticmethod def get_kv_cache_shape( @@ -83,7 +76,7 @@ def write_to_paged_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - cache_ops.reshape_and_cache( + ops.reshape_and_cache( key, value, key_cache, @@ -99,14 +92,27 @@ def forward_decode( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, + seq_lens: torch.Tensor, + max_seq_len: int, kv_cache_dtype: str, num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, ) -> torch.Tensor: + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape @@ -115,12 +121,12 @@ def forward_decode( and head_size == 128 and block_size == 16 and kv_cache_dtype == "auto" and (gqa_ratio >= 1 and gqa_ratio <= 16) - and max_context_len <= 32768) + and max_seq_len <= 32768) if not use_custom: _PARTITION_SIZE = _PARTITION_SIZE_V1V2 else: _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM - max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use @@ -129,7 +135,7 @@ def forward_decode( # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = (max_context_len <= 8192 + use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) and not use_custom) if use_v1: @@ -142,12 +148,17 @@ def forward_decode( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, ) else: # Run PagedAttention V2 or PagedAttention Custom. @@ -175,12 +186,17 @@ def forward_decode( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, ) else: paged_attention_custom( @@ -194,9 +210,9 @@ def forward_decode( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, ) @@ -210,11 +226,12 @@ def forward_prefix( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, - prompt_lens_tensor: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, - max_subquery_len: int, + max_query_len: int, alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], ) -> torch.Tensor: output = torch.empty_like(query) context_attention_fwd( @@ -225,12 +242,13 @@ def forward_prefix( key_cache, value_cache, block_tables, - # subquery_start_loc is (batch_size + 1,) - subquery_start_loc[:-1], - prompt_lens_tensor, + # query_start_loc is (batch_size + 1,) + query_start_loc[:-1], + seq_lens_tensor, context_lens, - max_subquery_len, + max_query_len, alibi_slopes, + sliding_window, ) return output @@ -238,21 +256,21 @@ def forward_prefix( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 70f09224f1cf6..b99cf9a50d105 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -47,8 +47,10 @@ def _fwd_kernel( stride_v_cache_bl, num_queries_per_kv: int, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -59,64 +61,96 @@ def _fwd_kernel( cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + # start position inside of the query + # generally, N goes over kv, while M goes over query_len block_start_loc = BLOCK_M * start_m # initialize offsets + # [N]; starts at 0 offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], + dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) for start_n in range(0, cur_batch_ctx_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) + other=0) # [N] + # [D,N] off_k = (bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) + # [N,D] off_v = ( bn[:, None] * stride_v_cache_bs + cur_kv_head * stride_v_cache_h + offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] qk += tl.dot(q, k) qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")) qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). + qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + # -- update output accumulator -- # scale p p_scale = beta / l_i_new @@ -126,8 +160,9 @@ def _fwd_kernel( acc = acc * acc_scale[:, None] # update acc v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] p = p.to(v.dtype) acc += tl.dot(p, v) @@ -142,23 +177,29 @@ def _fwd_kernel( k_ptrs = K + off_k v_ptrs = V + off_v - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + # compute query against itself (with causal mask) for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale + # apply causal mask qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -179,8 +220,8 @@ def _fwd_kernel( # update acc v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_query_len), other=0.0) p = p.to(v.dtype) @@ -195,7 +236,8 @@ def _fwd_kernel( out_ptrs = Out + off_o tl.store(out_ptrs, acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len)) return @triton.jit @@ -430,7 +472,8 @@ def _fwd_kernel_alibi( stride_v_cache_bl, num_queries_per_kv: int, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, ): # attn_bias[] @@ -451,21 +494,24 @@ def _fwd_kernel_alibi( # initialize offsets offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) alibi_slope = tl.load(Alibi_slopes + cur_head) alibi_start_q = tl.arange( @@ -490,8 +536,9 @@ def _fwd_kernel_alibi( offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -525,7 +572,8 @@ def _fwd_kernel_alibi( acc = acc * acc_scale[:, None] # update acc v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) p = p.to(v.dtype) @@ -558,8 +606,9 @@ def _fwd_kernel_alibi( # -- compute qk ---- k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len), other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -595,8 +644,9 @@ def _fwd_kernel_alibi( # update acc v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len), other=0.0) p = p.to(v.dtype) @@ -614,7 +664,8 @@ def _fwd_kernel_alibi( out_ptrs = Out + off_o tl.store(out_ptrs, acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) return @torch.inference_mode() @@ -629,14 +680,16 @@ def context_attention_fwd(q, b_seq_len, b_ctx_len, max_input_len, - alibi_slopes=None): + alibi_slopes=None, + sliding_window=None): cap = torch.cuda.get_device_capability() BLOCK = 128 if cap[0] >= 8 else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] @@ -644,6 +697,10 @@ def context_attention_fwd(q, grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + num_warps = 8 if Lk <= 64 else 8 if alibi_slopes is not None: _fwd_kernel_alibi[grid]( @@ -690,6 +747,7 @@ def context_attention_fwd(q, num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, @@ -738,7 +796,9 @@ def context_attention_fwd(q, num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, + SLIDING_WINDOW=sliding_window, num_warps=num_warps, num_stages=1, ) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index c99029175b5a2..f94211116a746 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -61,55 +61,81 @@ def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): @triton.jit -def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): - if offset_first is not None and offset_second is not None: - mask = (offset_first[:, None] < boundary_first) & \ - (offset_second[None, :] < boundary_second) - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_first is not None: - mask = offset_first[:, None] < boundary_first - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_second is not None: - mask = offset_second[None, :] < boundary_second - tensor = tl.load(ptrs, mask=mask, other=0.0) +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) else: - tensor = tl.load(ptrs) + tensor = tl.load(block_ptr) return tensor @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, - actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, - IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr): +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - if MASK_STEPS: - k_offs_n = start_n + tl.arange(0, BLOCK_N) - else: - k_offs_n = None - k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) - k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) if PRE_LOAD_V: - # We can use the same offsets as k, just with dims transposed. - v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: + if MASK_STEPS: # noqa: SIM102 # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. - # last step might get wasted but that is okay. check if this masking works For - # that case. + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) @@ -119,15 +145,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) - if bias_ptrs is not None: - bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None - bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) - # While bias is added after multiplying qk with sm_scale, - # our optimization to use 2^x instead of e^x results in an additional + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. - qk += (bias * 1.44269504089) - - # softmax + qk += bias * 1.44269504089 m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) @@ -135,29 +159,51 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N - keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) if RETURN_ENCODED_SOFTMAX: - tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: - tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(v.type.element_ty), v) - k_ptrs += BLOCK_N * stride_kn - v_ptrs += BLOCK_N * stride_vk - if bias_ptrs is not None: - bias_ptrs += BLOCK_N * stride_bn + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) if RETURN_ENCODED_SOFTMAX: - encoded_sm_ptrs += BLOCK_N + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) return acc, l_i, m_i @@ -260,20 +306,60 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], ) @triton.jit -def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, - stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, - stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, - HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, +): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -298,82 +384,120 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s # This block of code determines what N is, and if this WG is operating # on those M rows. n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if (IS_CAUSAL): + if IS_CAUSAL: # If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is part of - # the blocks that are all 0. We exit early. + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - o_ptrs_mask = offs_m[:, None] < seqlen_q # We still need to write 0s to the result - tl.store(o_ptrs, acc, mask=o_ptrs_mask) - # The tensor allocated for L is based on MAX_SEQLENS_Q as that is - # statically known. - l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # We store inf to LSE, not -inf because in the bwd pass, we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. - l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - l_ptrs_mask = offs_m < MAX_SEQLENS_Q - tl.store(l_ptrs, l, mask=l_ptrs_mask) - # TODO: Should dropout and return encoded softmax be handled here too? + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? return # If MQA / GQA, set the K and V head offsets appropriately. GROUP_SIZE: tl.constexpr = HQ // HK - if GROUP_SIZE != 1: - off_h_k = off_h_q // GROUP_SIZE - else: - off_h_k = off_h_q + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q n_extra_tokens = 0 if seqlen_k < BLOCK_N: n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N - PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL # Compute pointers for all the tensors used in this kernel. - q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn - v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn - if USE_BIAS: - # Note: this might get large enough to overflow on some configs - bias_offset = off_h_q * stride_bh - bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn - else: - bias_ptrs = None - - if USE_ALIBI: - a_offset = off_z * stride_az + off_h_q * stride_ah - alibi_slope = tl.load(alibi_slopes + a_offset) + q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) else: - alibi_slope = None - + bias_ptr = None if ENABLE_DROPOUT: - off_hz = off_z * HQ + off_h_q - batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + batch_philox_offset = philox_offset_base \ + + (off_z * HQ + off_h_q) \ + * seqlen_q * seqlen_k else: batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. In - # this case, we return an invalid pointer so indicate the mask is not valid. + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. if RETURN_ENCODED_SOFTMAX: - encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k - encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * seqlen_k + offs_n[None, :] + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) else: - encoded_sm_ptrs = None + encoded_softmax_block_ptr = 0 # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) @@ -382,11 +506,8 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s # have native e^x support in HW. qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. - q_ptrs_mask = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) - q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) - q = (q * qk_scale).to(q.type.element_ty) + q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -398,50 +519,96 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s else: # Padding on Q does not need to be masked in the FA loop. masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. - # In this case we might exceed n_blocks so pick the min. + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. masked_blocks = min(masked_blocks, n_blocks) n_full_blocks = n_blocks - masked_blocks block_min = 0 block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its actual + # Compute for full blocks. Here we set causal to false regardless of its # value because there is no masking. Similarly we do not need padding. if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, - encoded_sm_ptrs, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, alibi_slope, - # IS_CAUSAL, .... - False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) block_min = block_max block_max = n_blocks * BLOCK_N tl.debug_barrier() # Remaining blocks, if any, are full / not masked. - if (masked_blocks > 0): - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 - k_ptrs += n_full_blocks * BLOCK_N * stride_kn - v_ptrs += n_full_blocks * BLOCK_N * stride_vk - if USE_BIAS: - bias_ptrs += n_full_blocks * BLOCK_N * stride_bn + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) if RETURN_ENCODED_SOFTMAX: - encoded_sm_ptrs += n_full_blocks * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, - encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, - n_extra_tokens, alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL) + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) # epilogue acc = acc / l_i[:, None] if ENABLE_DROPOUT: @@ -454,34 +621,45 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: + if IS_CAUSAL: # noqa: SIM102 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + out_ptrs_mask = (mask_m_offsets[:, None] >= + out_mask_boundary[None, :]) z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE - l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. - # This is only true for the last M block. For others, overflow_size will be -ve - overflow_size = end_m_idx - seqlen_q - if overflow_size > 0: - boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) - l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary - tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - else: - tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on - o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) - if overflow_size > 0: - o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) - if PADDED_HEAD: - o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) - tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) def check_args( @@ -594,8 +772,6 @@ def forward( ) else: bias_strides = (0, 0, 0, 0) - alibi_strides = (0, 0) - M = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) attn_fwd[grid]( q, @@ -603,21 +779,19 @@ def forward( v, bias, sm_scale, - M, + None, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, - *alibi_strides, cu_seqlens_q, cu_seqlens_k, dropout_p=0.0, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, - alibi_slopes=None, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, @@ -626,8 +800,7 @@ def forward( IS_CAUSAL=causal, VARLEN=True, BLOCK_DMODEL=padded_d_model, - USE_BIAS=bias is not None, - USE_ALIBI=False, + BIAS_TYPE=0 if bias is None else 1, ENABLE_DROPOUT=False, RETURN_ENCODED_SOFTMAX=False, ) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4c699aed48d49..9ceda3431b898 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,9 +1,10 @@ import enum from functools import lru_cache -from typing import Type +from typing import Optional, Type import torch +import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.utils import is_cpu, is_hip @@ -16,17 +17,37 @@ class _Backend(enum.Enum): XFORMERS = enum.auto() ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() + FLASHINFER = enum.auto() @lru_cache(maxsize=None) -def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: - backend = _which_attn_to_use(dtype) +def get_attn_backend( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_blocksparse: bool = False, +) -> Type[AttentionBackend]: + + if is_blocksparse: + logger.info("Using BlocksparseFlashAttention backend.") + from vllm.attention.backends.blocksparse_attn import ( + BlocksparseFlashAttentionBackend) + return BlocksparseFlashAttentionBackend + """Determine which attention backend to use and only import + the selected backend module. + """ + backend = which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) if backend == _Backend.FLASH_ATTN: - logger.info("Using FlashAttention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend - elif backend == _Backend.XFORMERS: + if backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 XFormersBackend) @@ -40,39 +61,104 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend return TorchSDPABackend + elif backend == _Backend.FLASHINFER: + logger.info("Using Flashinfer backend.") + logger.warning("Eager mode is required for the Flashinfer backend. " + "Please make sure --enforce-eager is set.") + from vllm.attention.backends.flashinfer import FlashInferBackend + return FlashInferBackend else: raise ValueError("Invalid attention backend.") -def _which_attn_to_use(dtype: torch.dtype) -> _Backend: +def which_attn_to_use( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, +) -> _Backend: """Returns which flash attention backend to use.""" + + # Default case. + selected_backend = _Backend.FLASH_ATTN + + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + backend_members = _Backend.__members__ + if backend_by_env_var not in backend_members: + raise ValueError( + f"Invalid attention backend '{backend_by_env_var}'. " + f"Available backends: {', '.join(backend_members)} " + "(case-sensitive).") + selected_backend = _Backend[backend_by_env_var] + if is_cpu(): + if selected_backend != _Backend.TORCH_SDPA: + logger.info("Cannot use %s backend on CPU.", selected_backend) return _Backend.TORCH_SDPA if is_hip(): # AMD GPUs. - if torch.cuda.get_device_capability()[0] != 9: - # not Instinct series GPUs. - logger.info("flash_atten is not supported on NAVI GPUs.") + selected_backend = (_Backend.ROCM_FLASH if selected_backend + == _Backend.FLASH_ATTN else selected_backend) + if selected_backend == _Backend.ROCM_FLASH: + if torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_attn is not supported on NAVI GPUs.") + else: + logger.info("%s is not supported in AMD GPUs.", selected_backend) return _Backend.ROCM_FLASH - # NVIDIA GPUs. - if torch.cuda.get_device_capability()[0] < 8: - # Volta and Turing NVIDIA GPUs. - logger.info("Cannot use FlashAttention backend for Volta and Turing " - "GPUs.") - return _Backend.XFORMERS - - if dtype not in (torch.float16, torch.bfloat16): - logger.info("Cannot use FlashAttention backend for dtype other than " - "torch.float16 or torch.bfloat16.") - return _Backend.XFORMERS - - try: - import flash_attn # noqa: F401 - except ImportError: - logger.info( - "Cannot use FlashAttention backend because the flash_attn package " - "is not found. Please install it for better performance.") - return _Backend.XFORMERS - return _Backend.FLASH_ATTN + # FlashAttn in NVIDIA GPUs. + if selected_backend == _Backend.FLASH_ATTN: + if torch.cuda.get_device_capability()[0] < 8: + # Volta and Turing NVIDIA GPUs. + logger.info( + "Cannot use FlashAttention-2 backend for Volta and Turing " + "GPUs.") + selected_backend = _Backend.XFORMERS + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention-2 backend for dtype other than " + "torch.float16 or torch.bfloat16.") + selected_backend = _Backend.XFORMERS + elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + logger.info( + "Cannot use FlashAttention-2 backend for FP8 KV cache.") + selected_backend = _Backend.XFORMERS + elif block_size % 16 != 0: + logger.info( + "Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + selected_backend = _Backend.XFORMERS + elif sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window.") + selected_backend = _Backend.XFORMERS + + # FlashAttn is valid for the model, checking if the package is installed. + if selected_backend == _Backend.FLASH_ATTN: + try: + import vllm_flash_attn # noqa: F401 + + from vllm.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend) + + supported_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in supported_sizes: + logger.info( + "Cannot use FlashAttention-2 backend for head size %d.", + head_size) + selected_backend = _Backend.XFORMERS + except ImportError: + logger.info( + "Cannot use FlashAttention-2 backend because the " + "vllm_flash_attn package is not found. " + "`pip install vllm-flash-attn` for better performance.") + selected_backend = _Backend.XFORMERS + + return selected_backend diff --git a/vllm/config.py b/vllm/config.py index 5496c892c1638..63471aa5301b1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,24 +1,27 @@ import enum import json import os -from dataclasses import dataclass, fields -from typing import TYPE_CHECKING, ClassVar, Optional, Union +from dataclasses import dataclass, field, fields +from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union import torch -from packaging.version import Version from transformers import PretrainedConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, - is_neuron) +from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup + from vllm.model_executor.model_loader.loader import BaseModelLoader + logger = init_logger(__name__) _GB = 1 << 30 +_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 class ModelConfig: @@ -26,23 +29,13 @@ class ModelConfig: Args: model: Name or path of the huggingface model to use. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. tokenizer: Name or path of the huggingface tokenizer to use. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. - download_dir: Directory to download and load the weights, default to the - default cache directory of huggingface. - load_format: The format of the model weights to load: - "auto" will try to load the weights in the safetensors format and - fall back to the pytorch bin format if safetensors format is - not available. - "pt" will load the weights in the pytorch bin format. - "safetensors" will load the weights in the safetensors format. - "npcache" will load the weights in pytorch format and store - a numpy cache to speed up the loading. - "dummy" will initialize the weights with random values, which is - mainly for profiling. dtype: Data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. @@ -53,6 +46,9 @@ class ModelConfig: code_revision: The specific revision to use for the model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. + rope_scaling: Dictionary containing the scaling configuration for the + RoPE embeddings. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. tokenizer_revision: The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. @@ -62,15 +58,28 @@ class ModelConfig: weights. If None, we assume the model weights are not quantized. quantization_param_path: Path to JSON file containing scaling factors. Used to load KV cache scaling factors into the model when KV cache - type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also - be used to load activation and weight scaling factors when the + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the model dtype is FP8_E4M3 on ROCm. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode. + to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode + disable_sliding_window: Whether to disable sliding window. If True, + we will disable the sliding window functionality of the model. + If the model does not support sliding window, this argument is + ignored. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. + served_model_name: The model name used in metrics tag `model_name`, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, + the model name will be the same as `model`. """ def __init__( @@ -79,91 +88,62 @@ def __init__( tokenizer: str, tokenizer_mode: str, trust_remote_code: bool, - download_dir: Optional[str], - load_format: str, dtype: Union[str, torch.dtype], seed: int, revision: Optional[str] = None, code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, quantization: Optional[str] = None, quantization_param_path: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 5, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, ) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode self.trust_remote_code = trust_remote_code - self.download_dir = download_dir - self.load_format = load_format self.seed = seed self.revision = revision self.code_revision = code_revision + self.rope_scaling = rope_scaling self.tokenizer_revision = tokenizer_revision self.quantization = quantization self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + if self.max_context_len_to_capture is not None: + raise ValueError("`max_context_len_to_capture` is deprecated. " + "Use `max_seq_len_to_capture` instead.") + self.max_seq_len_to_capture = (max_seq_len_to_capture + or max_context_len_to_capture) self.max_logprobs = max_logprobs - - if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": - # download model from ModelScope hub, - # lazy import so that modelscope is not required for normal use. - # pylint: disable=C. - from modelscope.hub.snapshot_download import snapshot_download - - if not os.path.exists(model): - model_path = snapshot_download(model_id=model, - cache_dir=download_dir, - revision=revision) - else: - model_path = model - self.model = model_path - self.download_dir = model_path - self.tokenizer = model_path + self.disable_sliding_window = disable_sliding_window + self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision) + code_revision, rope_scaling) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) - self.max_model_len = _get_and_verify_max_len(self.hf_text_config, - max_model_len) - self._verify_load_format() - self._verify_tokenizer_mode() + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window()) + self.served_model_name = get_served_model_name(model, + served_model_name) + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() - def _verify_load_format(self) -> None: - load_format = self.load_format.lower() - supported_load_format = [ - "auto", "pt", "safetensors", "npcache", "dummy" - ] - rocm_not_supported_load_format = [] - if load_format not in supported_load_format: - raise ValueError( - f"Unknown load format: {self.load_format}. Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") - if is_hip() and load_format in rocm_not_supported_load_format: - rocm_supported_load_format = [ - f for f in supported_load_format - if (f not in rocm_not_supported_load_format) - ] - raise ValueError( - f"load format '{load_format}' is not supported in ROCm. " - f"Supported load format are " - f"{rocm_supported_load_format}") - - # TODO: Remove this check once HF updates the pt weights of Mixtral. - architectures = getattr(self.hf_config, "architectures", []) - if "MixtralForCausalLM" in architectures and load_format == "pt": - raise ValueError( - "Currently, the 'pt' format is not supported for Mixtral. " - "Please use the 'safetensors' format instead. ") - self.load_format = load_format - def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() if tokenizer_mode not in ["auto", "slow"]: @@ -172,29 +152,44 @@ def _verify_tokenizer_mode(self) -> None: "either 'auto' or 'slow'.") self.tokenizer_mode = tokenizer_mode + def _verify_embedding_mode(self) -> None: + architectures = getattr(self.hf_config, "architectures", []) + self.embedding_mode = any( + ModelRegistry.is_embedding_model(arch) for arch in architectures) + + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is None: + # SparseML uses a "compression_config" with a "quantization_config". + compression_cfg = getattr(self.hf_config, "compression_config", + None) + if compression_cfg is not None: + quant_cfg = compression_cfg.get("quantization_config", None) + + return quant_cfg + def _verify_quantization(self) -> None: - supported_quantization = ["awq", "gptq", "squeezellm", "marlin", "fp8"] - rocm_not_supported_quantization = ["awq", "marlin"] + supported_quantization = [*QUANTIZATION_METHODS] + rocm_supported_quantization = ["gptq", "squeezellm", "fp8"] if self.quantization is not None: self.quantization = self.quantization.lower() # Parse quantization method from the HF model config, if available. - quant_cfg = getattr(self.hf_config, "quantization_config", None) + quant_cfg = self._parse_quant_hf_config() + if quant_cfg is not None: quant_method = quant_cfg.get("quant_method", "").lower() - # compat: autogptq >=0.8.0 use checkpoint_format: str - # compat: autogptq <=0.7.1 is_marlin_format: bool - is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin" - or quant_cfg.get("is_marlin_format", False)) - - # Use marlin if the GPTQ model is serialized in marlin format. - if quant_method == "gptq" and is_format_marlin: - logger.info("The model is serialized in Marlin format. " - "Using Marlin kernel.") - quant_method = "marlin" - if self.quantization == "gptq": - self.quantization = quant_method + # Detect which checkpoint is it + for _, method in QUANTIZATION_METHODS.items(): + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization) + if quantization_override: + quant_method = quantization_override + self.quantization = quantization_override + break + + # Verify quantization configurations. if self.quantization is None: self.quantization = quant_method elif self.quantization != quant_method: @@ -210,21 +205,22 @@ def _verify_quantization(self) -> None: f"Unknown quantization method: {self.quantization}. Must " f"be one of {supported_quantization}.") if is_hip( - ) and self.quantization in rocm_not_supported_quantization: + ) and self.quantization not in rocm_supported_quantization: raise ValueError( f"{self.quantization} quantization is currently not " f"supported in ROCm.") - if self.quantization != "marlin": + if (self.quantization + not in ["marlin", "gptq_marlin_24", "gptq_marlin"]): logger.warning( - f"{self.quantization} quantization is not fully " + "%s quantization is not fully " "optimized yet. The speed can be slower than " - "non-quantized models.") + "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: - if self.max_context_len_to_capture is None: - self.max_context_len_to_capture = self.max_model_len - self.max_context_len_to_capture = min(self.max_context_len_to_capture, - self.max_model_len) + if self.max_seq_len_to_capture is None: + self.max_seq_len_to_capture = self.max_model_len + self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, + self.max_model_len) def verify_with_parallel_config( self, @@ -246,7 +242,7 @@ def verify_with_parallel_config( "must be divisible by pipeline parallel size " f"({pipeline_parallel_size}).") - def get_sliding_window(self) -> Optional[int]: + def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled. """ @@ -258,6 +254,15 @@ def get_sliding_window(self) -> Optional[int]: return None return getattr(self.hf_text_config, "sliding_window", None) + def get_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled. + """ + # If user disables sliding window, return None. + if self.disable_sliding_window: + return None + # Otherwise get the value from the hf config. + return self.get_hf_config_sliding_window() + def get_vocab_size(self) -> int: return self.hf_text_config.vocab_size @@ -320,6 +325,11 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) + def get_num_attention_heads(self, + parallel_config: "ParallelConfig") -> int: + return self.hf_text_config.num_attention_heads // \ + parallel_config.tensor_parallel_size + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size @@ -357,6 +367,7 @@ def __init__( self.enable_prefix_caching = enable_prefix_caching self._verify_args() self._verify_cache_dtype() + self._verify_prefix_caching() # Will be set after profiling. self.num_gpu_blocks = None @@ -376,23 +387,28 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype == "fp8": - if not is_hip(): - nvcc_cuda_version = get_nvcc_cuda_version() - if nvcc_cuda_version < Version("11.8"): - raise ValueError( - "FP8 is not supported when cuda version is" - "lower than 11.8.") + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " - "But it may cause slight accuracy drop without scaling " - "factors. FP8_E5M2 (without scaling) is only supported on " - "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 " - "is instead supported for common inference criteria.") + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + if self.cache_dtype == "fp8": + raise NotImplementedError( + "Prefix caching is not supported for fp8 cache_dtype. " + "Run with --kv-cache-dtype auto to use prefix caching.") + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", @@ -409,13 +425,13 @@ def verify_with_parallel_config( if cpu_memory_usage > 0.7 * total_cpu_memory: raise ValueError("Too large swap space. " + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: - logger.warning("Possibly too large swap space. " + msg) + logger.warning("Possibly too large swap space. %s", msg) @dataclass class TokenizerPoolConfig: """Configuration for the tokenizer pool. - + Args: pool_size: Number of tokenizer workers in the pool. pool_type: Type of the pool. @@ -439,9 +455,9 @@ def create_config( tokenizer_pool_extra_config: Optional[Union[str, dict]] ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. - + If tokenizer_pool_size is 0, return None. - + Args: tokenizer_pool_size: Number of tokenizer workers in the pool. tokenizer_pool_type: Type of the pool. @@ -464,15 +480,73 @@ def create_config( return tokenizer_pool_config +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + TENSORIZER = "tensorizer" + SHARDED_STATE = "sharded_state" + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + """ + + load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field( + default_factory=dict) + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads( + model_loader_extra_config) + self._verify_load_format() + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in LoadFormat.__members__ + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}") + + class ParallelConfig: """Configuration for the distributed execution. Args: pipeline_parallel_size: Number of pipeline parallel groups. tensor_parallel_size: Number of tensor parallel groups. - worker_use_ray: Whether to use Ray for model workers. Will be set to - True if either pipeline_parallel_size or tensor_parallel_size is - greater than 1. + worker_use_ray: Deprecated, use distributed_executor_backend instead. max_parallel_loading_workers: Maximum number of multiple batches when load model sequentially. To avoid RAM OOM when using tensor parallel and large models. @@ -482,48 +556,71 @@ class ParallelConfig: If None, will use synchronous tokenization. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + placement_group: ray distributed model workers placement group. + distributed_executor_backend: Backend to use for distributed model + workers, either "ray" or "mp" (multiprocessing). If either + pipeline_parallel_size or tensor_parallel_size is greater than 1, + will default to "ray" if Ray is installed or "mp" otherwise. """ def __init__( self, pipeline_parallel_size: int, tensor_parallel_size: int, - worker_use_ray: bool, + worker_use_ray: Optional[bool] = None, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, ray_workers_use_nsight: bool = False, placement_group: Optional["PlacementGroup"] = None, + distributed_executor_backend: Optional[str] = None, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size - self.worker_use_ray = worker_use_ray - self.worker_use_torchrun = False + self.distributed_executor_backend = distributed_executor_backend self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce self.tokenizer_pool_config = tokenizer_pool_config self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group + self.distributed_executor_backend = distributed_executor_backend self.world_size = pipeline_parallel_size * self.tensor_parallel_size - if self.world_size > 1: - if is_hip() and not self.worker_use_ray: + if worker_use_ray: + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + elif self.distributed_executor_backend != "ray": + raise ValueError(f"worker-use-ray can't be used with " + f"distributed executor backend " + f"'{self.distributed_executor_backend}'.") + + if self.distributed_executor_backend is None and self.world_size > 1: + if is_hip(): logger.info("Using torchrun for multi-GPU on " - "ROCM platform. Use --worker-use-ray " - "to override") + "ROCM platform. Use --worker-use-ray or " + "--distributed-executor-backend={ray, mp} to " + "override") if not os.environ.get("RANK"): raise RuntimeError( "Needs to be run in torchrun: " "torchrun --standalone --nproc_per_node= ...") - self.worker_use_torchrun = True + self.distributed_executor_backend = "torchrun" else: - self.worker_use_ray = True + from vllm.executor import ray_utils + ray_found = ray_utils.ray is not None + self.distributed_executor_backend = "ray" if ray_found else "mp" + self._verify_args() def _verify_args(self) -> None: if self.pipeline_parallel_size > 1: raise NotImplementedError( "Pipeline parallelism is not supported yet.") + if self.distributed_executor_backend not in ("ray", "mp", "torchrun", + None): + raise ValueError( + "Unrecognized distributed executor backend. Supported values " + "are 'ray' or 'mp' or 'torchrun'.") if not self.disable_custom_all_reduce and self.world_size > 1: if is_hip(): self.disable_custom_all_reduce = True @@ -535,7 +632,8 @@ def _verify_args(self) -> None: logger.info( "Disabled the custom all-reduce kernel because it is not " "supported with pipeline parallelism.") - if self.ray_workers_use_nsight and not self.worker_use_ray: + if self.ray_workers_use_nsight and ( + not self.distributed_executor_backend == "ray"): raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") @@ -559,6 +657,7 @@ class SchedulerConfig: prompt latency) before scheduling next prompt. enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. + embedding_mode: Whether the running model is for embedding. """ def __init__( @@ -570,19 +669,33 @@ def __init__( num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, + embedding_mode: Optional[bool] = False, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: - # If max_model_len is too short, use 2048 as the default value for - # higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + # It is the values that have the best balance between ITL + # and TTFT on A100. Note it is not optimized for throughput. + self.max_num_batched_tokens = 512 + elif embedding_mode: + # For embedding, choose specific value for higher throughput + self.max_num_batched_tokens = max( + max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) + else: + # If max_model_len is too short, use 2048 as the default value + # for higher throughput. + self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + logger.info("Chunked prefill is enabled (EXPERIMENTAL).") + self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.use_v2_block_manager = use_v2_block_manager self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill + self.embedding_mode = embedding_mode self._verify_args() @@ -649,6 +762,12 @@ def maybe_create_spec_config( target_dtype: str, speculative_model: Optional[str], num_speculative_tokens: Optional[int], + speculative_max_model_len: Optional[int], + enable_chunked_prefill: bool, + use_v2_block_manager: bool, + speculative_disable_by_batch_size: Optional[int], + ngram_prompt_lookup_max: Optional[int], + ngram_prompt_lookup_min: Optional[int], ) -> Optional["SpeculativeConfig"]: """Create a SpeculativeConfig if possible, else return None. @@ -666,13 +785,29 @@ def maybe_create_spec_config( model, if provided. num_speculative_tokens (Optional[int]): The number of speculative tokens, if provided. + speculative_max_model_len (Optional[int]): The maximum model len of + the speculative model. Used when testing the ability to skip + speculation for some sequences. + enable_chunked_prefill (bool): Whether vLLM is configured to use + chunked prefill or not. Used for raising an error since its not + yet compatible with spec decode. + use_v2_block_manager (bool): Whether vLLM is configured to use the + v2 block manager or not. Used for raising an error since the v2 + block manager is required with spec decode. + speculative_disable_by_batch_size (Optional[int]): Disable + speculative decoding for new incoming requests when the number + of enqueue requests is larger than this value, if provided. + ngram_prompt_lookup_max (Optional[int]): Max size of ngram token + window, if provided. + ngram_prompt_lookup_min (Optional[int]): Min size of ngram token + window, if provided. Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. """ - if (speculative_model is None and num_speculative_tokens is None): + if speculative_model is None and num_speculative_tokens is None: return None if speculative_model is not None and num_speculative_tokens is None: @@ -681,41 +816,121 @@ def maybe_create_spec_config( "num_speculative_tokens to be provided, but found " f"{speculative_model=} and {num_speculative_tokens=}.") + if (speculative_disable_by_batch_size is not None + and speculative_disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{speculative_disable_by_batch_size=}") + + assert (speculative_model is not None + and num_speculative_tokens is not None) + + if enable_chunked_prefill: + raise ValueError( + "Speculative decoding and chunked prefill are " + f"currently mutually exclusive ({enable_chunked_prefill=}).") + + if not use_v2_block_manager: + raise ValueError( + "Speculative decoding requires usage of the V2 " + "block manager. Enable it with --use-v2-block-manager.") + # TODO: The user should be able to specify revision/quantization/max # model len for the draft model. It is not currently supported. draft_revision = None draft_code_revision = None draft_quantization = None - draft_max_model_len = None - - draft_model_config = ModelConfig( - model=speculative_model, - tokenizer=target_model_config.tokenizer, - tokenizer_mode=target_model_config.tokenizer_mode, - trust_remote_code=target_model_config.trust_remote_code, - download_dir=target_model_config.download_dir, - load_format=target_model_config.load_format, - dtype=target_model_config.dtype, - seed=target_model_config.seed, - revision=draft_revision, - code_revision=draft_code_revision, - tokenizer_revision=target_model_config.tokenizer_revision, - max_model_len=draft_max_model_len, - quantization=draft_quantization, - enforce_eager=target_model_config.enforce_eager, - max_context_len_to_capture=target_model_config. - max_context_len_to_capture, - max_logprobs=target_model_config.max_logprobs, - ) - draft_parallel_config = ( - SpeculativeConfig.create_draft_parallel_config( - target_parallel_config)) + if speculative_model == "[ngram]": + if ngram_prompt_lookup_min is None: + ngram_prompt_lookup_min = 1 + if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1: + raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0") + if ngram_prompt_lookup_min < 1: + raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0") + if ngram_prompt_lookup_min > ngram_prompt_lookup_max: + raise ValueError(f"{ngram_prompt_lookup_min=} cannot be " + f"larger than {ngram_prompt_lookup_max=}") + + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. + draft_model_config = target_model_config + draft_parallel_config = target_parallel_config + else: + ngram_prompt_lookup_max = 0 + ngram_prompt_lookup_min = 0 + draft_model_config = ModelConfig( + model=speculative_model, + tokenizer=target_model_config.tokenizer, + tokenizer_mode=target_model_config.tokenizer_mode, + trust_remote_code=target_model_config.trust_remote_code, + dtype=target_model_config.dtype, + seed=target_model_config.seed, + revision=draft_revision, + code_revision=draft_code_revision, + tokenizer_revision=target_model_config.tokenizer_revision, + max_model_len=None, + quantization=draft_quantization, + enforce_eager=target_model_config.enforce_eager, + max_seq_len_to_capture=target_model_config. + max_seq_len_to_capture, + max_logprobs=target_model_config.max_logprobs, + ) + + draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + speculative_max_model_len, + draft_model_config.max_model_len, + target_model_config.max_model_len, + )) + + draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + target_parallel_config)) return SpeculativeConfig( draft_model_config, draft_parallel_config, num_speculative_tokens, + speculative_disable_by_batch_size, + ngram_prompt_lookup_max, + ngram_prompt_lookup_min, + ) + + @staticmethod + def _maybe_override_draft_max_model_len( + speculative_max_model_len: Optional[int], + draft_max_model_len: int, + target_max_model_len: int, + ) -> int: + """Determine the max sequence len for the draft model. This is usually + the draft_max_model_len, but may be the target_max_model_len if it is + less than the draft_max_model_len, or may be speculative_max_model_len + if it is specified. + + This is necessary so that sequences do not exceed the capacity of the + draft model or the target model. + + speculative_max_model_len is mainly used for testing that sequences can + skip speculation. + """ + + if speculative_max_model_len is not None: + + if speculative_max_model_len > draft_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}") + + if speculative_max_model_len > target_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}") + + return speculative_max_model_len + + return min( + draft_max_model_len, + target_max_model_len, ) @staticmethod @@ -730,7 +945,8 @@ def create_draft_parallel_config( pipeline_parallel_size=target_parallel_config. pipeline_parallel_size, tensor_parallel_size=target_parallel_config.tensor_parallel_size, - worker_use_ray=target_parallel_config.worker_use_ray, + distributed_executor_backend=target_parallel_config. + distributed_executor_backend, max_parallel_loading_workers=target_parallel_config. max_parallel_loading_workers, disable_custom_all_reduce=target_parallel_config. @@ -748,6 +964,9 @@ def __init__( draft_model_config: ModelConfig, draft_parallel_config: ParallelConfig, num_speculative_tokens: int, + speculative_disable_by_batch_size: Optional[int], + ngram_prompt_lookup_max: Optional[int], + ngram_prompt_lookup_min: Optional[int], ): """Create a SpeculativeConfig object. @@ -756,10 +975,19 @@ def __init__( draft_parallel_config: ParallelConfig for the draft model. num_speculative_tokens: The number of tokens to sample from the draft model before scoring with the target model. + speculative_disable_by_batch_size: Disable speculative + decoding for new incoming requests when the number of + enqueue requests is larger than this value. + ngram_prompt_lookup_max: Max size of ngram token window. + ngram_prompt_lookup_min: Min size of ngram token window. """ self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens + self.speculative_disable_by_batch_size = \ + speculative_disable_by_batch_size + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 self._verify_args() @@ -783,7 +1011,10 @@ def num_lookahead_slots(self) -> int: return self.num_speculative_tokens def __repr__(self) -> str: - draft_model = self.draft_model_config.model + if self.ngram_prompt_lookup_max > 0: + draft_model = "[ngram]" + else: + draft_model = self.draft_model_config.model num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})" @@ -792,11 +1023,13 @@ def __repr__(self) -> str: class LoRAConfig: max_lora_rank: int max_loras: int + fully_sharded_loras: bool = False max_cpu_loras: Optional[int] = None lora_dtype: Optional[torch.dtype] = None lora_extra_vocab_size: int = 256 # This is a constant. lora_vocab_padding_size: ClassVar[int] = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h @@ -824,9 +1057,12 @@ def verify_with_model_config(self, model_config: ModelConfig): self.lora_dtype = model_config.dtype elif isinstance(self.lora_dtype, str): self.lora_dtype = getattr(torch, self.lora_dtype) - if model_config.quantization is not None: - raise ValueError( - "LoRA is not supported with quantized models yet.") + if model_config.quantization and model_config.quantization not in [ + "awq", "gptq" + ]: + # TODO support marlin and squeezellm + logger.warning("%s quantization is not tested with LoRA yet.", + model_config.quantization) def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): if scheduler_config.max_num_batched_tokens > 65528: @@ -886,7 +1122,7 @@ def get_image_input_enum_type( "bfloat16": torch.bfloat16, } -_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"] +_ROCM_NOT_SUPPORTED_DTYPE: List[str] = [] # def _get_and_verify_dtype( @@ -905,6 +1141,7 @@ def _get_and_verify_dtype( if config_dtype == torch.float32: # Following the common practice, we use float16 for float32 # models. + logger.info("Casting torch.float32 to torch.float16.") torch_dtype = torch.float16 else: torch_dtype = config_dtype @@ -917,25 +1154,19 @@ def _get_and_verify_dtype( else: raise ValueError(f"Unknown dtype: {dtype}") - if is_hip() and torch_dtype == torch.float32: - rocm_supported_dtypes = [ - k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() - if (k not in _ROCM_NOT_SUPPORTED_DTYPE) - ] - raise ValueError(f"dtype '{dtype}' is not supported in ROCm. " - f"Supported dtypes are {rocm_supported_dtypes}") - # Verify the dtype. if torch_dtype != config_dtype: if torch_dtype == torch.float32: # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) pass elif config_dtype == torch.float32: # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) pass else: # Casting between float16 and bfloat16 is allowed with a warning. - logger.warning(f"Casting {config_dtype} to {torch_dtype}.") + logger.warning("Casting %s to %s.", config_dtype, torch_dtype) return torch_dtype @@ -943,6 +1174,8 @@ def _get_and_verify_dtype( def _get_and_verify_max_len( hf_config: PretrainedConfig, max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window_len: Optional[int], ) -> int: """Get and verify the model's maximum length.""" derived_max_model_len = float("inf") @@ -962,6 +1195,7 @@ def _get_and_verify_max_len( "max_seq_length", "seq_len", ] + # Choose the smallest "max_length" from the possible keys. max_len_key = None for key in possible_keys: max_len = getattr(hf_config, key, None) @@ -969,6 +1203,16 @@ def _get_and_verify_max_len( max_len_key = key if max_len < derived_max_model_len \ else max_len_key derived_max_model_len = min(derived_max_model_len, max_len) + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if disable_sliding_window and sliding_window_len is not None: + max_len_key = "sliding_window" \ + if sliding_window_len < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, sliding_window_len) + + # If none of the keys were found in the config, use a default and + # log a warning. if derived_max_model_len == float("inf"): if max_model_len is not None: # If max_model_len is specified, we use it. @@ -978,12 +1222,19 @@ def _get_and_verify_max_len( logger.warning( "The model's config.json does not contain any of the following " "keys to determine the original maximum length of the model: " - f"{possible_keys}. Assuming the model's maximum length is " - f"{default_max_len}.") + "%d. Assuming the model's maximum length is %d.", possible_keys, + default_max_len) derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) - if rope_scaling is not None: + if rope_scaling is not None and rope_scaling["type"] != "su": + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. Please raise an issue so we can " + "investigate.") assert "factor" in rope_scaling scaling_factor = rope_scaling["factor"] if rope_scaling["type"] == "yarn": @@ -991,14 +1242,23 @@ def _get_and_verify_max_len( "original_max_position_embeddings"] derived_max_model_len *= scaling_factor + # If the user specified a max length, make sure it is smaller than the + # derived length from the HF model config. if max_model_len is None: - max_model_len = derived_max_model_len + max_model_len = int(derived_max_model_len) elif max_model_len > derived_max_model_len: # Some models might have a separate key for specifying model_max_length # that will be bigger than derived_max_model_len. We compare user input # with model_max_length and allow this override when it's smaller. model_max_length = getattr(hf_config, "model_max_length", None) if model_max_length is not None and max_model_len <= model_max_length: + if disable_sliding_window: + # TODO(robertgshaw): Find a model that has model_max_length + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "model_max_length in the config. Please raise an issue " + "so we can investigate.") pass else: raise ValueError( @@ -1011,6 +1271,37 @@ def _get_and_verify_max_len( return int(max_model_len) +def get_served_model_name(model: str, + served_model_name: Optional[Union[str, List[str]]]): + """ + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an + empty list, the fallback is to use `self.model`. + """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + +@dataclass +class DecodingConfig: + """Dataclass which contains the decoding strategy of the engine""" + + # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer' + guided_decoding_backend: str = 'outlines' + + def __post_init__(self): + valid_guided_backends = ['outlines', 'lm-format-enforcer'] + backend = self.guided_decoding_backend + if backend not in valid_guided_backends: + raise ValueError(f"Invalid guided_decoding_backend '{backend}," + f"must be one of {valid_guided_backends}") + + @dataclass(frozen=True) class EngineConfig: """Dataclass which contains all engine-related configuration. This @@ -1022,9 +1313,11 @@ class EngineConfig: parallel_config: ParallelConfig scheduler_config: SchedulerConfig device_config: DeviceConfig + load_config: LoadConfig lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] + decoding_config: Optional[DecodingConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index ba061bbc4fbcb..26c704b8de901 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -20,6 +20,10 @@ class BlockTable: _blocks (Optional[List[Block]], optional): An optional list of existing blocks to initialize the BlockTable with. If not provided, an empty BlockTable is created. + max_block_sliding_window (Optional[int], optional): The number of + blocks to keep around for each sequance. If None, all blocks + are kept (eg., when sliding window is not used). + It should at least fit the sliding window size of the model. Attributes: _block_size (int): The maximum number of tokens that can be stored in a @@ -37,11 +41,15 @@ def __init__( block_size: int, block_allocator: DeviceAwareBlockAllocator, _blocks: Optional[List[Block]] = None, + max_block_sliding_window: Optional[int] = None, ): self._block_size = block_size self._allocator = block_allocator - self._blocks: Optional[List[Block]] = _blocks + if _blocks is None: + _blocks = [] + self._blocks: List[Block] = _blocks + self._max_block_sliding_window = max_block_sliding_window # Use helper method instead of directly calculating, as blocks # may not be allocated. self._num_full_slots = len(self._get_all_token_ids()) @@ -87,7 +95,8 @@ def allocate(self, def append_token_ids(self, token_ids: List[int], - num_lookahead_slots: int = 0) -> None: + num_lookahead_slots: int = 0, + num_computed_slots: Optional[int] = None) -> None: """Appends a sequence of token IDs to the existing blocks in the BlockTable. @@ -102,13 +111,35 @@ def append_token_ids(self, Args: token_ids (List[int]): The sequence of token IDs to be appended. + num_computed_slots (Optional[int]): The number of KV cache slots + that are already filled (computed). + When sliding window is enabled, this is used to compute how many + blocks to drop at the front of the sequence. + Without sliding window, None can be passed. + Without chunked prefill, it should be the same as + _num_full_slots. """ - assert self._is_allocated - assert token_ids, "can't append empty token ids" - + assert self._is_allocated, "no blocks have been allocated" + assert len(self._blocks) > 0 + + # Drop blocks that are no longer needed due to sliding window + if self._max_block_sliding_window is not None: + null_block = self._allocator.allocate_or_get_null_block() + assert num_computed_slots is not None + end_block_idx = (num_computed_slots // + self._block_size) - self._max_block_sliding_window + for idx in range(0, end_block_idx): + b = self._blocks[idx] + if b is not null_block: + self._allocator.free(b) + self._blocks[idx] = null_block + + # Ensure there are enough empty slots for the new tokens plus + # lookahead slots self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + num_lookahead_slots) + # Update the blocks with the new tokens blocks = self._blocks[self._num_full_slots // self._block_size:] token_blocks = self._chunk_token_blocks_for_append(token_ids) @@ -141,6 +172,7 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None: blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) for _ in range(blocks_to_allocate): + assert len(self._blocks) > 0 self._blocks.append( self._allocator.allocate_mutable(prev_block=self._blocks[-1], device=device)) @@ -159,11 +191,13 @@ def fork(self) -> "BlockTable": the current instance. """ assert self._is_allocated + assert len(self._blocks) > 0 forked_blocks = self._allocator.fork(self._blocks[-1]) return BlockTable( block_size=self._block_size, block_allocator=self._allocator, _blocks=forked_blocks, + max_block_sliding_window=self._max_block_sliding_window, ) def free(self) -> None: @@ -177,10 +211,10 @@ def free(self) -> None: assert self._is_allocated for block in self._blocks: self._allocator.free(block) - self._blocks = None + self._blocks = [] @property - def physical_block_ids(self) -> List[int]: + def physical_block_ids(self) -> List[Optional[int]]: """Returns a list of physical block indices for the blocks in the BlockTable. @@ -235,7 +269,7 @@ def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], def _get_all_token_ids(self) -> List[int]: # NOTE: This function is O(seq_len); use sparingly. - token_ids = [] + token_ids: List[int] = [] if not self._is_allocated: return token_ids @@ -247,7 +281,7 @@ def _get_all_token_ids(self) -> List[int]: @property def _is_allocated(self) -> bool: - return self._blocks is not None + return len(self._blocks) > 0 @property def _num_empty_slots(self) -> int: diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index 50c70533c4fbc..4d7a12165cb01 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -1,5 +1,4 @@ -from collections import defaultdict -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, Protocol, Tuple from vllm.core.block.interfaces import Block, BlockAllocator @@ -7,7 +6,19 @@ RefCount = int -class RefCounter: +class RefCounterProtocol(Protocol): + + def incr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def decr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def get(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + +class RefCounter(RefCounterProtocol): """A class for managing reference counts for a set of block indices. The RefCounter class maintains a dictionary that maps block indices to their @@ -54,7 +65,7 @@ def as_readonly(self) -> "ReadOnlyRefCounter": return ReadOnlyRefCounter(self) -class ReadOnlyRefCounter: +class ReadOnlyRefCounter(RefCounterProtocol): """A read-only view of the RefCounter class. The ReadOnlyRefCounter class provides a read-only interface to access the @@ -96,10 +107,10 @@ class CopyOnWriteTracker: def __init__( self, - refcounter: RefCounter, + refcounter: RefCounterProtocol, allocator: BlockAllocator, ): - self._copy_on_writes = defaultdict(list) + self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] self._refcounter = refcounter self._allocator = allocator @@ -138,25 +149,27 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: prev_block=block.prev_block).block_id # Track src/dst copy. - self._copy_on_writes[src_block_id].append(block_id) + assert src_block_id is not None + assert block_id is not None + self._copy_on_writes.append((src_block_id, block_id)) return block_id - def clear_cows(self) -> Dict[BlockId, List[BlockId]]: + def clear_cows(self) -> List[Tuple[BlockId, BlockId]]: """Clears the copy-on-write tracking information and returns the current state. - This method returns a dictionary mapping source block indices to lists - of destination block indices for the current copy-on-write operations. + This method returns a list mapping source block indices to + destination block indices for the current copy-on-write operations. It then clears the internal tracking information. Returns: - Dict[BlockId, List[BlockId]]: A dictionary mapping source - block indices to lists of destination block indices for the + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices for the current copy-on-write operations. """ - cows = dict(self._copy_on_writes) - self._copy_on_writes.clear() + cows = self._copy_on_writes + self._copy_on_writes = [] return cows @@ -180,6 +193,6 @@ def recurse(block: Block, lst: List[Block]) -> None: recurse(block.prev_block, lst) lst.append(block) - all_blocks = [] + all_blocks: List[Block] = [] recurse(last_block, all_blocks) return all_blocks diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 3135e194c5937..d28a684376974 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -1,6 +1,6 @@ -from typing import Dict, List, Optional +from typing import Dict, FrozenSet, List, Optional, Tuple -from vllm.core.block.interfaces import (Block, BlockAllocator, +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator @@ -57,15 +57,15 @@ def create( cpu_block_ids = block_ids[num_gpu_blocks:] if allocator_type == "naive": - gpu_allocator = NaiveBlockAllocator( - create_block=NaiveBlock, + gpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore num_blocks=num_gpu_blocks, block_size=block_size, block_ids=gpu_block_ids, ) - cpu_allocator = NaiveBlockAllocator( - create_block=NaiveBlock, + cpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore num_blocks=num_cpu_blocks, block_size=block_size, block_ids=cpu_block_ids, @@ -105,11 +105,19 @@ def __init__( Device.GPU: gpu_block_allocator, } - self._block_ids_to_allocator = {} + self._null_block: Optional[Block] = None + + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} for _, allocator in self._allocators.items(): for block_id in allocator.all_block_ids: self._block_ids_to_allocator[block_id] = allocator + def allocate_or_get_null_block(self) -> Block: + if self._null_block is None: + self._null_block = NullBlock( + self.allocate_mutable(None, Device.GPU)) + return self._null_block + def allocate_mutable(self, prev_block: Optional[Block], device: Device) -> Block: """Allocates a new mutable block on the specified device. @@ -149,7 +157,12 @@ def free(self, block: Block) -> None: Args: block (Block): The block to be freed. """ - allocator = self._block_ids_to_allocator[block.block_id] + # Null block should never be freed + if isinstance(block, NullBlock): + return + block_id = block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] return allocator.free(block) def fork(self, last_block: Block) -> List[Block]: @@ -163,7 +176,11 @@ def fork(self, last_block: Block) -> List[Block]: List[Block]: A new list of blocks that shares the same memory as the original sequence. """ - allocator = self._block_ids_to_allocator[last_block.block_id] + # do not attempt to fork the null block + assert not isinstance(last_block, NullBlock) + block_id = last_block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] return allocator.fork(last_block) def get_num_free_blocks(self, device: Device) -> int: @@ -171,29 +188,40 @@ def get_num_free_blocks(self, device: Device) -> int: Args: device (Device): The device for which to query the number of free - blocks. + blocks. AssertionError is raised if None is passed. Returns: int: The number of free blocks available on the specified device. """ return self._allocators[device].get_num_free_blocks() - def clear_copy_on_writes(self) -> Dict[int, List[int]]: + def get_num_total_blocks(self, device: Device) -> int: + return self._allocators[device].get_num_total_blocks() + + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: """Clears the copy-on-write (CoW) state and returns the mapping of source to destination block IDs. Returns: - Dict[int, List[int]]: A dictionary mapping source block IDs to lists - of destination block IDs. + List[Tuple[int, int]]: A list mapping source block IDs to + destination block IDs. """ # CoW only supported on GPU device = Device.GPU return self._allocators[device].clear_copy_on_writes() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, only use for prefix caching.""" # Prefix caching only supported on GPU. device = Device.GPU - return self._allocators[device].mark_blocks_as_computed() + return self._allocators[device].mark_blocks_as_accessed(block_ids, now) + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_computed(block_ids) def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: @@ -202,5 +230,73 @@ def get_common_computed_block_ids( return self._allocators[device].get_common_computed_block_ids( seq_block_ids) - def all_block_ids(self) -> frozenset[int]: + @property + def all_block_ids(self) -> FrozenSet[int]: return frozenset(self._block_ids_to_allocator.keys()) + + def promote_to_immutable_block(self, block: Block) -> BlockId: + raise NotImplementedError + + def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: + raise NotImplementedError + + +class NullBlock(Block): + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + This implementation just wraps an ordinary block and prevents it from + being modified. It also allows for testing if a block is NullBlock + via isinstance(). + """ + + def __init__(self, proxy: Block): + super().__init__() + self._proxy = proxy + + def append_token_ids(self, token_ids: List[BlockId]): + raise ValueError("null block should not be modified") + + @property + def block_id(self): + return self._proxy.block_id + + @block_id.setter + def block_id(self, value: Optional[BlockId]): + raise ValueError("null block should not be modified") + + @property + def token_ids(self) -> List[BlockId]: + return self._proxy.token_ids + + @property + def num_empty_slots(self) -> BlockId: + return self._proxy.num_empty_slots + + @property + def is_full(self): + return self._proxy.is_full + + @property + def prev_block(self): + return self._proxy.prev_block + + @property + def computed(self): + return self._proxy.computed + + @computed.setter + def computed(self, value): + self._proxy.computed = value + + @property + def last_accessed(self) -> float: + return self._proxy.last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._proxy.last_accessed = last_accessed_ts + + @property + def content_hash(self): + return self._proxy.content_hash diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 9f466566f096b..8fc4c601106cd 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,8 +1,10 @@ -from abc import ABC, abstractmethod, abstractproperty -from typing import Dict, List, Optional, Protocol +from abc import ABC, abstractmethod +from typing import FrozenSet, List, Optional, Protocol, Tuple from vllm.utils import Device +BlockId = int + class Block(ABC): @@ -10,26 +12,58 @@ class Block(ABC): def append_token_ids(self, token_ids: List[int]) -> None: pass - @abstractproperty + @property + @abstractmethod def block_id(self) -> Optional[int]: pass - @abstractproperty + @block_id.setter + @abstractmethod + def block_id(self, value: Optional[int]) -> None: + """NOTE: Do not use this API outside Block.""" + self._block_id = value + + @property + @abstractmethod def token_ids(self) -> List[int]: pass - @abstractproperty + @property + @abstractmethod def num_empty_slots(self) -> int: pass - @abstractproperty + @property + @abstractmethod def is_full(self) -> bool: pass - @abstractproperty + @property + @abstractmethod def prev_block(self) -> Optional["Block"]: pass + @property + @abstractmethod + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + @abstractmethod + def computed(self, value) -> bool: + """Should be only used by PrefixCacingAllocator""" + raise NotImplementedError + + @property + @abstractmethod + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + @abstractmethod + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + class Factory(Protocol): @abstractmethod @@ -43,6 +77,17 @@ def __call__( ) -> "Block": pass + @property + @abstractmethod + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined or not supported. + + For the content-based hash to be defined, the current block must be + full. + """ + return None + class BlockAllocator(ABC): @@ -63,20 +108,30 @@ def free(self, block: Block) -> None: def fork(self, last_block: Block) -> List[Block]: pass + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + @abstractmethod def get_num_free_blocks(self) -> int: pass - @abstractproperty - def all_block_ids(self) -> frozenset[int]: + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: pass @abstractmethod - def clear_copy_on_writes(self) -> Dict[int, List[int]]: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: pass @abstractmethod - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass @abstractmethod @@ -84,11 +139,21 @@ def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: pass + @abstractmethod + def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def promote_to_immutable_block(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + class NoFreeBlocksError(ValueError): pass -class DeviceAwareBlockAllocator(BlockAllocator): +class DeviceAwareBlockAllocator(ABC): @abstractmethod def allocate_mutable(self, prev_block: Optional[Block], @@ -103,3 +168,47 @@ def allocate_immutable(self, prev_block: Optional[Block], @abstractmethod def get_num_free_blocks(self, device: Device) -> int: pass + + @abstractmethod + def get_num_total_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, seq_block_ids: List[List[int]]) -> List[int]: + pass + + @abstractmethod + def allocate_or_get_null_block(self) -> Block: + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + There is at most one null block per allocator. + """ + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index f8e9265bb2d67..ae01930878254 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -1,10 +1,9 @@ -from typing import Dict, Iterable, List, Optional, Set +from typing import FrozenSet, Iterable, List, Optional, Set, Tuple from vllm.core.block.common import (CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device -BlockId = int Refcount = int @@ -49,8 +48,10 @@ def __init__( allocator=self, ) - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + def allocate_immutable(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: """Allocates a new immutable block with the given token IDs, linked to the previous block. @@ -63,11 +64,14 @@ def allocate_immutable(self, prev_block: Optional[Block], Returns: Block: The newly allocated immutable block. """ + assert device is None block = self.allocate_mutable(prev_block=prev_block) block.append_token_ids(token_ids) return block - def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + def allocate_mutable(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: """Allocates a new mutable block, linked to the previous block. Args: @@ -78,6 +82,7 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block: Returns: Block: The newly allocated mutable block. """ + assert device is None block_id = self._allocate_new_block_id() return self._create_block( prev_block=prev_block, @@ -88,6 +93,7 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block: ) def free(self, block: Block) -> None: + assert block.block_id is not None self._free_block_id(block.block_id) # Mark the block as having no allocation. @@ -111,6 +117,7 @@ def fork(self, last_block: Block) -> List[Block]: for block in source_blocks: # Increment refcount for each block. + assert block.block_id is not None refcount = self._refcounter.incr(block.block_id) assert refcount != 1, "can't fork free'd block" @@ -129,6 +136,9 @@ def fork(self, last_block: Block) -> List[Block]: def get_num_free_blocks(self) -> int: return len(self._free_block_indices) + def get_num_total_blocks(self) -> int: + return len(self._all_block_indices) + def _allocate_new_block_id(self) -> BlockId: if not self._free_block_indices: raise BlockAllocator.NoFreeBlocksError() @@ -148,7 +158,7 @@ def refcounter(self): return self._refcounter @property - def all_block_ids(self): + def all_block_ids(self) -> FrozenSet[int]: return self._all_block_indices def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: @@ -165,16 +175,25 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: """ return self._cow_tracker.cow_block_if_not_appendable(block) - def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: """Returns the copy-on-write source->destination mapping and clears it. Returns: - Dict[BlockId, List[BlockId]]: A dictionary mapping source - block indices to lists of destination block indices. + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. """ return self._cow_tracker.clear_cows() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """Mark blocks as computed, used in prefix caching. Since the naive allocator does not implement prefix caching, we do @@ -191,6 +210,9 @@ def get_common_computed_block_ids( """ return [] + def promote_to_immutable_block(self, block: Block) -> BlockId: + raise NotImplementedError + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix @@ -215,13 +237,13 @@ class NaiveBlock(Block): """ def __init__(self, - prev_block: Block, + prev_block: Optional[Block], token_ids: List[int], block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, _cow_target: Optional[Block] = None): - self._token_ids = [] + self._token_ids: List[int] = [] self._block_size = block_size self._prev_block = prev_block self._block_id = block_id @@ -247,6 +269,22 @@ def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: assert self.num_empty_slots >= len(token_ids) self._token_ids.extend(token_ids) + @property + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + def computed(self, value) -> None: + raise NotImplementedError + + @property + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + @property def block_id(self) -> Optional[int]: return self._block_id @@ -267,9 +305,14 @@ def num_empty_slots(self) -> int: def token_ids(self) -> List[int]: return self._token_ids + @property def block_size(self) -> int: return self._block_size @property def prev_block(self) -> Optional["Block"]: return self._prev_block + + @property + def content_hash(self) -> Optional[int]: + return None diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 6aa75a8abb80a..4eb32f145b05b 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,15 +1,20 @@ """Token blocks.""" from itertools import takewhile from os.path import commonprefix -from typing import Dict, Iterable, List, Optional +from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple from vllm.core.block.common import (CopyOnWriteTracker, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor PrefixHash = int -BlockId = int + +# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME +# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, +# then we know this block hasn't been accessed yet. +_DEFAULT_LAST_ACCESSED_TIME = -1 class PrefixCachingBlockAllocator(BlockAllocator): @@ -27,26 +32,23 @@ class PrefixCachingBlockAllocator(BlockAllocator): from 0 to num_blocks - 1. """ - # TODO last access time / evictor integration - def __init__( self, num_blocks: int, block_size: int, block_ids: Optional[Iterable[int]] = None, + eviction_policy: EvictionPolicy = EvictionPolicy.LRU, ): # A mapping of prefix hash to block index. All blocks which have a # prefix hash will be in this dict, even if they have refcount 0. self._cached_blocks: Dict[PrefixHash, BlockId] = {} - # A mapping of prefix hash to block index. All blocks which have a - # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset - # of self._cached_blocks. - self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {} + # A mapping of blockId to Block to track those cached blocks + self._blocks: Dict[BlockId, Block] = {} # An allocator for blocks that do not have prefix hashes. self._hashless_allocator = NaiveBlockAllocator( - create_block=self._create_block, + create_block=self._create_block, # type: ignore num_blocks=num_blocks, block_size=block_size, block_ids=block_ids, @@ -54,6 +56,10 @@ def __init__( self._block_size = block_size + # Evitor used to maintain how we want to handle those computed blocks + # if we find memory pressure is high. + self.evictor: Evictor = make_evictor(eviction_policy) + # We share the refcounter between allocators. This allows us to promote # blocks originally allocated in the hashless allocator to immutable # blocks. @@ -72,6 +78,7 @@ def _create_block( block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, + computed: bool = False, ) -> Block: # Bind block to self. allocator = self @@ -82,10 +89,13 @@ def _create_block( block_size=block_size, block_id=block_id, prefix_caching_allocator=allocator, + computed=computed, ) - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + def allocate_immutable(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: """Allocates an immutable block with the given token IDs, reusing cached blocks if possible. @@ -96,6 +106,7 @@ def allocate_immutable(self, prev_block: Optional[Block], Returns: Block: The allocated immutable block. """ + assert device is None assert_prefix_caching_block_or_none(prev_block) block = self._create_block( @@ -109,65 +120,95 @@ def allocate_immutable(self, prev_block: Optional[Block], cached_block_id = self._cached_blocks.get(block.content_hash, None) if cached_block_id is not None: block.block_id = cached_block_id - self._incr_refcount_cached_block(block.content_hash, - block.block_id) + self._incr_refcount_cached_block(block, block.block_id) return block block = self.allocate_mutable(prev_block) block.append_token_ids(token_ids) assert block.content_hash is not None - # TODO computed bit return block - def allocate_mutable(self, prev_block: Block) -> Block: + def allocate_mutable(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: """Allocates a mutable block. If there are no free blocks, this will evict unused cached blocks. Args: prev_block (Block): The previous block in the sequence. + None is not allowed unlike it is super class. Returns: Block: The allocated mutable block. """ + assert device is None assert_prefix_caching_block_or_none(prev_block) try: - return self._hashless_allocator.allocate_mutable( + block = self._hashless_allocator.allocate_mutable( prev_block=prev_block) + + assert block.block_id not in self._blocks + assert block.block_id is not None + self._blocks[block.block_id] = block + return block except BlockAllocator.NoFreeBlocksError: # We must check the unused cached blocks before raising OOM. pass - if self._unused_cached_blocks: - # TODO policy for selecting block to remove - content_hash_to_evict = next(iter(self._unused_cached_blocks)) + # If the evictor has blocks available for eviction, evict a block + # and return it. + if self.evictor.num_blocks > 0: + # here we get an evicted block, which is only added + # into evictor if its ref counter is 0 + # and since its content would be changed, we need + # to remove it from _cached_blocks's tracking list + block_id, content_hash_to_evict = self.evictor.evict() + + _block_id = self._cached_blocks[content_hash_to_evict] + assert self._refcounter.get(_block_id) == 0 + assert _block_id == block_id - # Clear content hash mapping; the block will be overwritten. - del self._cached_blocks[content_hash_to_evict] + self._cached_blocks.pop(content_hash_to_evict) - block_id = self._unused_cached_blocks.pop(content_hash_to_evict) - refcount = self._refcounter.incr(block_id) - assert refcount == 1 + self._refcounter.incr(block_id) + + # the block comes from evictor already contain computed result block = self._create_block( prev_block=prev_block, token_ids=[], block_size=self._block_size, allocator=self, block_id=block_id, + computed=True, ) assert block.content_hash is None + + assert block.block_id not in self._blocks + assert block.block_id is not None + self._blocks[block.block_id] = block return block # No block available in hashless allocator, nor in unused cache blocks. raise BlockAllocator.NoFreeBlocksError() - def _incr_refcount_cached_block(self, content_hash: int, + def _incr_refcount_cached_block(self, block: Block, block_id: BlockId) -> None: + # now _incr_refcount_cached_block comes from two place + # allocate_immutable/promote_to_immutable_block where hit + # _cached_blocks hash key. + # In both cases, it means that already exists a already + # computed block which shared with block now + block.computed = True + refcount = self._refcounter.incr(block_id) if refcount == 1: - assert content_hash in self._unused_cached_blocks - del self._unused_cached_blocks[content_hash] + # if block get referred, then it shall not be in evictor + # and put it into _blocks for tracking + if block_id in self.evictor: + self.evictor.remove(block_id) + self._blocks[block_id] = block def free(self, block: Block) -> None: """Decrement the refcount of the block. If the decremented refcount is @@ -180,22 +221,37 @@ def free(self, block: Block) -> None: is not None), "freeing unallocated block is undefined" self._free_block_id_for_block(block.block_id, block) + block.block_id = None def _free_block_id_for_block(self, block_id: BlockId, block: Block) -> None: assert isinstance(block, PrefixCachingBlock) - if block.content_hash is None: + # if we comes from promote_to_immutable_block, it means that + # block.content_hash is never None. + # However we need to release the same content block, so that + # physical block could get reused. + if block.block_id != block_id or block.content_hash is None: + refcount = self._refcounter.get(block_id) + # We have fork case where block would get more than one ref, + # so we cannot free it from tracking if ref cnt large than 1 + assert block.block_id is not None + refcount = self._refcounter.get(block.block_id) + if refcount == 1: + del self._blocks[block.block_id] + return self._hashless_allocator.free(block) refcount = self._refcounter.decr(block_id) - # If no longer used, add the block to the unused cached blocks. + # If no longer used, add the block to the evictor. if refcount == 0: - assert block.content_hash not in self._unused_cached_blocks assert block.content_hash in self._cached_blocks - self._unused_cached_blocks[block.content_hash] = block_id + assert block.block_id is not None + del self._blocks[block.block_id] + self.evictor.add(block.block_id, block.content_hash, + block.num_tokens_total, block.last_accessed) def fork(self, last_block: Block) -> List[Block]: """Creates a new sequence of blocks that shares the same underlying @@ -228,18 +284,21 @@ def fork(self, last_block: Block) -> List[Block]: return forked_blocks - def get_num_free_blocks(self) -> int: + def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + assert device is None # The number of free blocks is the number of hashless free blocks - # plus the number of hashful blocks that are unused. - return self._hashless_allocator.get_num_free_blocks() + len( - self._unused_cached_blocks) + # plus the number of blocks evictor could free from its list. + return self._hashless_allocator.get_num_free_blocks( + ) + self.evictor.num_blocks + + def get_num_total_blocks(self) -> int: + return self._hashless_allocator.get_num_total_blocks() @property - def all_block_ids(self) -> frozenset[int]: + def all_block_ids(self) -> FrozenSet[int]: return self._hashless_allocator.all_block_ids - def promote_to_immutable_block(self, - block: "PrefixCachingBlock") -> BlockId: + def promote_to_immutable_block(self, block: Block) -> BlockId: """Once a mutable block is full, it can be promoted to an immutable block. This means that its content can be referenced by future blocks having the same prefix. @@ -249,7 +308,7 @@ def promote_to_immutable_block(self, block. Args: - block (PrefixCachingBlock): The mutable block to be promoted. + block: The mutable block to be promoted. Returns: BlockId: Either the original block index, or the block index of @@ -264,9 +323,10 @@ def promote_to_immutable_block(self, if block.content_hash not in self._cached_blocks: self._cached_blocks[block.content_hash] = block.block_id else: - self._free_block_id_for_block(block.block_id, block) + self._free_block_id_for_block( + self._cached_blocks[block.content_hash], block) self._incr_refcount_cached_block( - block.content_hash, self._cached_blocks[block.content_hash]) + block, self._cached_blocks[block.content_hash]) return self._cached_blocks[block.content_hash] @@ -284,38 +344,72 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: """ return self._cow_tracker.cow_block_if_not_appendable(block) - def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: """Returns the copy-on-write source->destination mapping and clears it. Returns: - Dict[BlockId, List[BlockId]]: A dictionary mapping source - block indices to lists of destination block indices. + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. """ return self._cow_tracker.clear_cows() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + If the block is added into evictor, we need to update corresponding + info in evictor's metadata. + """ + + for block_id in block_ids: + if block_id in self._blocks: + self._blocks[block_id].last_accessed = now + elif block_id in self.evictor: + self.evictor.update(block_id, now) + else: + raise ValueError( + "Mark block as accessed which is not belonged to GPU") + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """Mark blocks as computed, used in prefix caching.""" - # TODO Track computed blocks. - pass + + for block_id in block_ids: + if block_id in self._blocks: + # only those full block is valid for prefix caching + if self._blocks[block_id].is_full: + self._blocks[block_id].computed = True + elif block_id not in self.evictor: + raise ValueError(f"Mark {block_id=} as computed which " + "is not belonged to GPU") + + def block_is_computed(self, block_id: int) -> bool: + if block_id in self._blocks: + return self._blocks[block_id].computed + else: + return block_id in self.evictor def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: """Return the block ids that are common for a given sequence group. - Used in prefill (can skip prefill of some blocks). + Only those blocks that are immutable and already be marked + compyted would be taken consideration. """ - # TODO: Track computed blocks. - computed = lambda block_id: False - # NOTE We exclude the last block to avoid the case where the entire # prompt is cached. This would cause erroneous behavior in model # runner. + ids_list = [ - takewhile(lambda block_id: computed(block_id), seq[:-1]) - for seq in seq_block_ids + list( + takewhile(lambda block_id: self.block_is_computed(block_id), + seq[:-1])) for seq in seq_block_ids ] - return commonprefix([ids for ids in ids_list if ids != []]) + # It returns a list of int although type annotation says list of string. + return commonprefix([ + ids for ids in ids_list # type: ignore + if ids != [] + ]) class PrefixCachingBlock(Block): @@ -332,7 +426,7 @@ class PrefixCachingBlock(Block): token_ids (List[int]): The initial token IDs to be stored in the block. block_size (int): The maximum number of token IDs that can be stored in the block. - prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix + prefix_caching_allocator (BlockAllocator): The prefix caching block allocator associated with this block. block_id (Optional[int], optional): The physical block index of this block. Defaults to None. @@ -340,17 +434,25 @@ class PrefixCachingBlock(Block): def __init__( self, - prev_block: Optional["PrefixCachingBlock"], + prev_block: Optional[Block], token_ids: List[int], block_size: int, - prefix_caching_allocator: PrefixCachingBlockAllocator, + prefix_caching_allocator: BlockAllocator, block_id: Optional[int] = None, + computed: bool = False, ): + assert isinstance(prefix_caching_allocator, + PrefixCachingBlockAllocator), ( + "Currently this class is only tested with " + "PrefixCachingBlockAllocator.") assert_prefix_caching_block_or_none(prev_block) self._prev_block = prev_block self._cached_content_hash: Optional[int] = None + self._cached_num_tokens_total: Optional[int] = None self._prefix_caching_allocator = prefix_caching_allocator + self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self._computed = computed self._block = NaiveBlock( prev_block=prev_block, @@ -361,6 +463,22 @@ def __init__( _cow_target=self, ) + @property + def computed(self) -> bool: + return self._computed + + @computed.setter + def computed(self, value) -> None: + self._computed = value + + @property + def last_accessed(self) -> float: + return self._last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._last_accessed = last_accessed_ts + def append_token_ids(self, token_ids: List[int]) -> None: """Appends the given token IDs to the block and registers the block as immutable if the block becomes full. @@ -398,6 +516,27 @@ def is_full(self) -> bool: def num_empty_slots(self) -> int: return self._block.num_empty_slots + @property + def num_tokens_total(self) -> int: + """return the total tokens so far. + + Here we iterate the block chain till to the first block, while + cache the result in local to prevent repeated computations. + """ + if self._cached_num_tokens_total is not None: + return self._cached_num_tokens_total + + _block: Optional[Block] = self + self._cached_num_tokens_total = 0 + + # TODO: current implement here take O(N^2), we expect future + # we have O(1) here + while _block is not None: + self._cached_num_tokens_total += len(_block.token_ids) + _block = _block.prev_block + + return self._cached_num_tokens_total + @property def block_size(self) -> int: return self._block.block_size @@ -428,8 +567,10 @@ def content_hash(self) -> Optional[int]: return None is_first_block = self._prev_block is None - prev_block_hash = (None if is_first_block else - self._prev_block.content_hash) + prev_block_hash = ( + None if is_first_block else + self._prev_block.content_hash # type: ignore + ) # Previous block exists but does not yet have a hash. # Return no hash in this case. diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py new file mode 100644 index 0000000000000..2c412a8f472e0 --- /dev/null +++ b/vllm/core/block/utils.py @@ -0,0 +1,56 @@ +"""Block manager utils.""" +from vllm.sequence import SequenceGroup + +# Exception strings for non-implemented block manager enc/dec scenarios + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + + +def _get_block_mgr_sliding_window_attr(block_mgr): + ''' + BlockManagerV1 and BlockManagerV2 have slightly different + members related to sliding window attention (SWA). This + function extracts the appropriate member to use for determining + whether SWA is enabled. + + Arguments: + + * block_mgr: BlockManagerV1 or BlockManagerV2 instance + ''' + + if hasattr(block_mgr, 'block_sliding_window'): + return block_mgr.block_sliding_window + if hasattr(block_mgr, 'max_block_sliding_window'): + return block_mgr.max_block_sliding_window + + raise AttributeError("Block manager instance has neither " + \ + "block_sliding_window nor " + \ + "max_block_sliding_window attributes.") + + +def check_no_caching_or_swa_for_blockmgr_encdec( + block_mgr, seq_group: SequenceGroup) -> None: + ''' + Enforce that prefix caching & sliding-window attention (SWA) + are currently unsupported *specifically* for encoder/decoder models. + + Raises NotImplementedError if unsupported scenario is detected. + + Arguments: + + * block_mgr: BlockSpaceManager instance + * seq_group: SequenceGroup passed to block_mgr + ''' + + if seq_group.is_encoder_decoder(): + if _get_block_mgr_sliding_window_attr(block_mgr) is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + if block_mgr.enable_caching: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index e7e3b4dc1e9b4..201cba309f6ef 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -1,11 +1,15 @@ """A block manager that manages token blocks.""" +import math from abc import ABC, abstractmethod from itertools import count, takewhile from os.path import commonprefix -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock -from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.sequence import Sequence, SequenceGroup, SequenceStatus @@ -44,6 +48,10 @@ def free(self, block: PhysicalTokenBlock) -> None: def get_num_free_blocks(self) -> int: pass + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + @abstractmethod def contains_block(self, block_hash: int) -> bool: pass @@ -128,6 +136,9 @@ def get_num_free_blocks(self) -> int: return (self.num_blocks - self.current_num_blocks + self.evictor.num_blocks) + def get_num_total_blocks(self) -> int: + return self.num_blocks + def contains_block(self, block_hash: int) -> bool: return block_hash in self.cached_blocks or block_hash in self.evictor @@ -187,6 +198,9 @@ def free(self, block: PhysicalTokenBlock) -> None: def get_num_free_blocks(self) -> int: return len(self.free_blocks) + def get_num_total_blocks(self) -> int: + return self.num_blocks + def contains_block(self, block_hash: int) -> bool: raise NotImplementedError( "Invalid codepath for uncached block allocator.") @@ -218,9 +232,9 @@ def __init__( self.block_sliding_window = None if sliding_window is not None: - assert sliding_window % block_size == 0, (sliding_window, - block_size) - self.block_sliding_window = sliding_window // block_size + # Round up to nearest block size to regularize sliding window + # allocation sizes. + self.block_sliding_window = math.ceil(sliding_window / block_size) self.watermark = watermark assert watermark >= 0.0 @@ -231,10 +245,10 @@ def __init__( if self.enable_caching: logger.info("Automatic prefix caching is enabled.") - self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size, - num_gpu_blocks) - self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size, - num_cpu_blocks) + self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator( + Device.GPU, block_size, num_gpu_blocks) + self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator( + Device.CPU, block_size, num_cpu_blocks) else: self.gpu_allocator = UncachedBlockAllocator( Device.GPU, block_size, num_gpu_blocks) @@ -242,14 +256,30 @@ def __init__( Device.CPU, block_size, num_cpu_blocks) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} + # Mapping: req_id -> BlockTable + # Note that each SequenceGroup has a unique + # request ID + self.cross_block_tables: Dict[str, BlockTable] = {} + + def _get_seq_num_required_blocks(self, seq: Sequence) -> int: + return 0 if seq is None \ + else len(seq.logical_token_blocks) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = len(seq.logical_token_blocks) + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + self_num_required_blocks = self._get_seq_num_required_blocks( + seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) + cross_num_required_blocks = self._get_seq_num_required_blocks( + seq_group.get_encoder_seq()) + num_required_blocks = self_num_required_blocks + \ + cross_num_required_blocks if self.block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, self.block_sliding_window) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() @@ -263,11 +293,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate(self, seq_group: SequenceGroup) -> None: - # NOTE: Here we assume that all sequences in the group have the same - # prompt. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - + def _allocate_sequence(self, \ + seq: Sequence, \ + ref_count: int, \ + is_encoder_decoder: bool = True) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = len(seq.logical_token_blocks) @@ -277,21 +306,46 @@ def allocate(self, seq_group: SequenceGroup) -> None: and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() - elif self.enable_caching: + block.ref_count = ref_count + elif not is_encoder_decoder and self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) else: block = self.gpu_allocator.allocate() # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() + block.ref_count = ref_count block_table.append(block) - # Assign the block table for each sequence. + return block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + is_encoder_decoder = seq_group.is_encoder_decoder() + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + # Allocate decoder sequences + # + # NOTE: Here we assume that all sequences in the group have the same + # decoder prompt. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + block_table: BlockTable = \ + self._allocate_sequence(seq, + seq_group.num_seqs(), + is_encoder_decoder) + + # Assign the self-attention block tables for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() + # Allocate encoder sequence + if is_encoder_decoder: + # A SequenceGroup has only a single encoder sequence (at most), + # thus allocate with a ref count of 1 + block_table = self._allocate_sequence(seq_group.get_encoder_seq(), + 1, is_encoder_decoder) + # Assign the cross-attention block table for the SequenceGroup. + self.cross_block_tables[seq_group.request_id] = block_table + def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> bool: @@ -373,7 +427,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int = 0, - ) -> Dict[int, List[int]]: + ) -> List[Tuple[int, int]]: """Allocate a physical slot for a new token.""" logical_blocks = seq.logical_token_blocks block_table = self.block_tables[seq.seq_id] @@ -388,11 +442,11 @@ def append_slots( block_table.append(block_table[len(block_table) % self.block_sliding_window]) else: - # The sequence has a new logical block. + # The sequence hash a new logical block. # Allocate a new physical block. new_block = self._allocate_last_physical_block(seq) block_table.append(new_block) - return {} + return [] # We want to append the token to the last physical block. last_block = block_table[-1] @@ -405,7 +459,7 @@ def append_slots( maybe_new_block = self._maybe_promote_last_block( seq, last_block) block_table[-1] = maybe_new_block - return {} + return [] else: # The last block is shared with other sequences. # Copy on Write: Allocate a new block and copy the tokens. @@ -413,7 +467,7 @@ def append_slots( block_table[-1] = new_block self.gpu_allocator.free(last_block) - return {last_block.block_number: [new_block.block_number]} + return [(last_block.block_number, new_block.block_number)] def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: # NOTE: fork does not allocate a new physical block. @@ -430,89 +484,117 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def _get_physical_blocks( self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: + # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. + request_id = seq_group.request_id blocks: Set[PhysicalTokenBlock] = set() for seq in seq_group.get_seqs(): if seq.is_finished(): continue blocks.update(self.block_tables[seq.seq_id]) + # Cross-attention blocks + if seq_group.is_encoder_decoder(): + blocks.update(self.cross_block_tables[request_id]) return list(blocks) def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> bool: + num_lookahead_slots: int = 0) -> AllocStatus: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) + if seq_group.is_encoder_decoder(): + num_swapped_seqs += 1 num_free_blocks = self.gpu_allocator.get_num_free_blocks() # NOTE: Conservatively, we assume that every sequence will allocate # at least one free block right after the swap-in. # NOTE: This should match the logic in can_append_slot(). num_required_blocks = len(blocks) + num_swapped_seqs - return num_free_blocks - num_required_blocks >= self.watermark_blocks + if self.gpu_allocator.get_num_total_blocks() < num_required_blocks: + return AllocStatus.NEVER + elif num_free_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def _swap_block_table( + self, block_table: BlockTable, src_allocator: BlockAllocatorBase, + dest_allocator: BlockAllocatorBase, + mapping: Dict[PhysicalTokenBlock, + PhysicalTokenBlock]) -> BlockTable: + new_block_table = [] + + for from_block in block_table: + if from_block in mapping: + to_block = mapping[from_block] + to_block.ref_count += 1 + else: + to_block = dest_allocator.allocate( + from_block.block_hash, from_block.num_hashed_tokens) + mapping[from_block] = to_block + new_block_table.append(to_block) + # Free the source block swapped in to destination. + src_allocator.free(from_block) + + return new_block_table def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> Dict[int, int]: + num_lookahead_slots: int = 0) -> List[Tuple[int, int]]: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + request_id = seq_group.request_id + # CPU block -> GPU block. + # dict is efficient in lookup `if cpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for cpu_block in block_table: - if cpu_block in mapping: - gpu_block = mapping[cpu_block] - gpu_block.ref_count += 1 - else: - gpu_block = self.gpu_allocator.allocate( - cpu_block.block_hash, cpu_block.num_hashed_tokens) - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) - self.block_tables[seq.seq_id] = new_block_table - - block_number_mapping = { - cpu_block.block_number: gpu_block.block_number - for cpu_block, gpu_block in mapping.items() - } - return block_number_mapping + self.block_tables[seq.seq_id] = \ + self._swap_block_table(self.block_tables[seq.seq_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) + + if seq_group.is_encoder_decoder(): + self.cross_block_tables[request_id] = \ + self._swap_block_table(self.cross_block_tables[request_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) + + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + request_id = seq_group.request_id + # GPU block -> CPU block. + # dict is efficient in lookup `if gpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for gpu_block in block_table: - if gpu_block in mapping: - cpu_block = mapping[gpu_block] - cpu_block.ref_count += 1 - else: - cpu_block = self.cpu_allocator.allocate( - gpu_block.block_hash, gpu_block.num_hashed_tokens) - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. - self.gpu_allocator.free(gpu_block) - self.block_tables[seq.seq_id] = new_block_table - - block_number_mapping = { - gpu_block.block_number: cpu_block.block_number - for gpu_block, cpu_block in mapping.items() - } - return block_number_mapping + self.block_tables[seq.seq_id] = \ + self._swap_block_table(self.block_tables[seq.seq_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) + + if seq_group.is_encoder_decoder(): + self.cross_block_tables[request_id] = \ + self._swap_block_table(self.cross_block_tables[request_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) + + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] def _free_block_table(self, block_table: BlockTable) -> None: # when using a sliding window, each seq will only use up @@ -537,15 +619,32 @@ def free(self, seq: Sequence) -> None: self._free_block_table(block_table) del self.block_tables[seq.seq_id] + def free_cross(self, seq_group: SequenceGroup) -> None: + if seq_group.request_id not in self.cross_block_tables: + # Already freed or hasn't ben scheduled yet. + return + block_table = self.cross_block_tables[seq_group.request_id] + self._free_block_table(block_table) + del self.cross_block_tables[seq_group.request_id] + def reset(self) -> None: + # Free decoder block tables for block_table in self.block_tables.values(): self._free_block_table(block_table) self.block_tables.clear() + # Free cross-attention block tables + for block_table in self.cross_block_tables.values(): + self._free_block_table(block_table) + self.cross_block_tables.clear() def get_block_table(self, seq: Sequence) -> List[int]: block_table = self.block_tables[seq.seq_id] return [block.block_number for block in block_table] + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + block_table = self.cross_block_tables[seq_group.request_id] + return [block.block_number for block in block_table] + def get_num_free_gpu_blocks(self) -> int: return self.gpu_allocator.get_num_free_blocks() @@ -588,7 +687,8 @@ def get_all_computed_blocks(self, seq: Sequence) -> List[int]: for b in takewhile(lambda b: b.computed, block_table[:-1]) ] - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: """Return the block ids that are common for a given sequence group. Used in prefill (can skip prefill of some blocks). diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 813e71ad883b2..cad42ab3c1ba2 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,13 +1,17 @@ """A block manager that manages token blocks.""" from typing import Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Tuple from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device SeqId = int +EncoderSeqId = str class BlockSpaceManagerV2(BlockSpaceManager): @@ -64,42 +68,57 @@ def __init__( self.num_total_gpu_blocks = num_gpu_blocks self.num_total_cpu_blocks = num_cpu_blocks - assert sliding_window is None, "Sliding window not yet supported" - - self.block_sliding_window = None + self.sliding_window = sliding_window + # max_block_sliding_window is the max number of blocks that need to be + # allocated + self.max_block_sliding_window = None + if sliding_window is not None: + # +1 here because // rounds down + num_blocks = sliding_window // block_size + 1 + # +1 here because the last block may not be full, + # and so the sequence stretches one more block at the beginning + # For example, if sliding_window is 3 and block_size is 4, + # we may need 2 blocks when the second block only holds 1 token. + self.max_block_sliding_window = num_blocks + 1 self.watermark = watermark assert watermark >= 0.0 - assert not enable_caching, "Prefix caching not yet supported" self.enable_caching = enable_caching self.watermark_blocks = int(watermark * num_gpu_blocks) self.block_allocator = CpuGpuBlockAllocator.create( - # Currently, only naive blocks are supported (no prefix caching). - allocator_type="naive", + allocator_type="prefix_caching" if enable_caching else "naive", num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, block_size=block_size, ) self.block_tables: Dict[SeqId, BlockTable] = {} + self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, ) - assert self.block_sliding_window is None - if self.block_sliding_window is not None: + if seq_group.is_encoder_decoder(): + num_required_blocks += BlockTable.get_num_required_blocks( + seq_group.get_encoder_seq().get_token_ids(), + block_size=self.block_size, + ) + + if self.max_block_sliding_window is not None: num_required_blocks = min(num_required_blocks, - self.block_sliding_window) + self.max_block_sliding_window) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( device=Device.GPU) @@ -113,7 +132,19 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER + def _allocate_sequence(self, seq: Sequence) -> BlockTable: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, + ) + block_table.allocate(seq.get_token_ids()) + + return block_table + def allocate(self, seq_group: SequenceGroup) -> None: + + # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) & self.block_tables.keys()), "block table already exists" @@ -121,19 +152,29 @@ def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # prompt. seq = waiting_seqs[0] - - block_table = BlockTable( - block_size=self.block_size, - block_allocator=self.block_allocator, - ) - assert self.block_sliding_window is None - block_table.allocate(seq.get_token_ids()) + block_table: BlockTable = self._allocate_sequence(seq) self.block_tables[seq.seq_id] = block_table # Assign the block table for each sequence. for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() + # Allocate cross-attention block table for encoder sequence + # + # NOTE: Here we assume that all sequences in the group have the same + # encoder prompt. + request_id = seq_group.request_id + + assert (request_id + not in self.cross_block_tables), \ + "block table already exists" + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + if seq_group.is_encoder_decoder(): + block_table = self._allocate_sequence(seq_group.get_encoder_seq()) + self.cross_block_tables[request_id] = block_table + def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> bool: """Determine if there is enough space in the GPU KV cache to continue @@ -167,13 +208,14 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, - ) -> Dict[int, List[int]]: + ) -> List[Tuple[int, int]]: block_table = self.block_tables[seq.seq_id] block_table.append_token_ids( token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), num_lookahead_slots=num_lookahead_slots, + num_computed_slots=seq.data.get_num_computed_tokens(), ) # Return any new copy-on-writes. @@ -187,25 +229,52 @@ def free(self, seq: Sequence) -> None: self.block_tables[seq.seq_id].free() del self.block_tables[seq.seq_id] + def free_cross(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.cross_block_tables: + # Already freed or hasn't been scheduled yet. + return + self.cross_block_tables[request_id].free() + del self.cross_block_tables[request_id] + def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables block_ids = self.block_tables[seq.seq_id].physical_block_ids assert all(b is not None for b in block_ids) - return block_ids + return block_ids # type: ignore - def access_all_blocks_in_seq(self, seq, now): - # TODO add prefix caching support. - # Tracked here https://github.com/vllm-project/vllm/issues/3667 - pass + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.cross_block_tables + block_ids = self.cross_block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids # type: ignore + + def access_all_blocks_in_seq(self, seq: Sequence, now: float): + # Update the last accessed time of all the blocks accessed + # in this step. + # And the accessed time is only useful for prefix caching now, + # as it support internal evictor policy for which cached + # block could be refilled, to keep cached content could be reused + # at max extend. + if self.enable_caching: + block_table = self.block_tables[seq.seq_id] + block_ids = [] + for block_id in block_table.physical_block_ids: + block_ids.append(block_id) + self.block_allocator.mark_blocks_as_accessed( + block_ids, # type: ignore + now) def mark_blocks_as_computed(self, seq_group: SequenceGroup): - # We ignore the sequence group as its not necessary. After the batch is - # formed by the scheduler, we do not need to mark blocks from individual - # sequence groups as computed -- all blocks in the batch can be marked - # as computed. - self.block_allocator.mark_blocks_as_computed() + # The only need for mark block as computed is for prefix caching, + # while currently we could determine whether one block is computed + # or not by check whether it has content hash. + # So this function is useless for block_v2. + pass - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: """Determine which blocks for which we skip prefill. With prefix caching we can skip prefill for previously-generated blocks. @@ -218,25 +287,26 @@ def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: seq_block_ids = [ self.block_tables[seq.seq_id].physical_block_ids for seq in seqs ] + # NOTE(sang): This assumes seq_block_ids doesn't contain any None. return self.block_allocator.get_common_computed_block_ids( - seq_block_ids) + seq_block_ids) # type: ignore def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.fork() def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - return False + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.LATER def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> Dict[int, int]: + num_lookahead_slots: int) -> List[Tuple[int, int]]: raise NotImplementedError def can_swap_out(self, seq_group: SequenceGroup) -> bool: return False - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: raise NotImplementedError def get_num_free_gpu_blocks(self) -> int: diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py new file mode 100644 index 0000000000000..a09d79ec3c420 --- /dev/null +++ b/vllm/core/embedding_model_block_manager.py @@ -0,0 +1,84 @@ +from typing import List, Tuple + +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup + + +class EmbeddingModelBlockSpaceManager(BlockSpaceManager): + """An embedding version of BlockSpaceManager for use in environments + with embedding models where block management is not required. + + This class provides the same interface as BlockSpaceManager, but its + methods perform no actions or return simple values like True in specific + actions. It's designed to be used in scenarios where the overhead of + block management is unnecessary, such as in an embedding environment. + """ + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # Always return OK for dummy purposes + return AllocStatus.OK + + def allocate(self, seq_group: SequenceGroup) -> None: + # No actual allocation logic needed + pass + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + return True + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + return None # type: ignore + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.OK + + def swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> List[Tuple[int, int]]: + return None # type: ignore + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def free(self, seq: Sequence) -> None: + # No operation on free + return + + def get_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_num_free_gpu_blocks(self) -> int: + return 1 + + def get_num_free_cpu_blocks(self) -> int: + return 1 + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + def get_common_computed_block_ids(self, + seq_group: SequenceGroup) -> List[int]: + return None # type: ignore + + def mark_blocks_as_computed(self, seq_group: SequenceGroup): + pass diff --git a/vllm/core/evictor.py b/vllm/core/evictor_v1.py similarity index 100% rename from vllm/core/evictor.py rename to vllm/core/evictor_v1.py diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py new file mode 100644 index 0000000000000..57759b29347f4 --- /dev/null +++ b/vllm/core/evictor_v2.py @@ -0,0 +1,127 @@ +import enum +from abc import ABC, abstractmethod, abstractproperty +from typing import OrderedDict, Tuple + + +class EvictionPolicy(enum.Enum): + """Enum for eviction policy used by make_evictor to instantiate the correct + Evictor subclass. + """ + LRU = enum.auto() + + +class Evictor(ABC): + """The Evictor subclasses should be used by the BlockAllocator class to + handle eviction of freed PhysicalTokenBlocks. + """ + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def __contains__(self, block_id: int) -> bool: + pass + + @abstractmethod + def evict(self) -> Tuple[int, int]: + """Runs the eviction algorithm and returns the evicted block's + content hash along with physical block id along with physical block id + """ + pass + + @abstractmethod + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + """Adds block to the evictor, making it a candidate for eviction""" + pass + + @abstractmethod + def update(self, block_id: int, last_accessed: float): + """Update corresponding block's access time in metadata""" + pass + + @abstractmethod + def remove(self, block_id: int): + """Remove a given block id from the cache.""" + pass + + @abstractproperty + def num_blocks(self) -> int: + pass + + +class BlockMetaData(): + """Data structure for storing key data describe cached block, so that + evitor could use to make its decision which one to choose for eviction + + Here we use physical block id as the dict key, as there maybe several + blocks with the same content hash, but their physical id is unique. + """ + + def __init__(self, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + self.content_hash = content_hash + self.num_hashed_tokens = num_hashed_tokens + self.last_accessed = last_accessed + + +class LRUEvictor(Evictor): + """Evicts in a least-recently-used order using the last_accessed timestamp + that's recorded in the PhysicalTokenBlock. If there are multiple blocks with + the same last_accessed time, then the one with the largest num_hashed_tokens + will be evicted. If two blocks each have the lowest last_accessed time and + highest num_hashed_tokens value, then one will be chose arbitrarily + """ + + def __init__(self): + self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict() + + def __contains__(self, block_id: int) -> bool: + return block_id in self.free_table + + def evict(self) -> Tuple[int, int]: + if len(self.free_table) == 0: + raise ValueError("No usable cache memory left") + + evicted_block = next(iter(self.free_table.values())) + evicted_block_id = next(iter(self.free_table.keys())) + # The blocks with the lowest timestamps should be placed consecutively + # at the start of OrderedDict. Loop through all these blocks to + # find the one with maximum number of hashed tokens. + for _id, block in self.free_table.items(): + if evicted_block.last_accessed > block.last_accessed or ( + evicted_block.last_accessed == block.last_accessed and + evicted_block.num_hashed_tokens < block.num_hashed_tokens): + evicted_block = block + evicted_block_id = _id + + self.free_table.pop(evicted_block_id) + + return evicted_block_id, evicted_block.content_hash + + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + self.free_table[block_id] = BlockMetaData(content_hash, + num_hashed_tokens, + last_accessed) + + def update(self, block_id: int, last_accessed: float): + self.free_table[block_id].last_accessed = last_accessed + + def remove(self, block_id: int): + if block_id not in self.free_table: + raise ValueError( + "Attempting to remove block that's not in the evictor") + self.free_table.pop(block_id) + + @property + def num_blocks(self) -> int: + return len(self.free_table) + + +def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: + if eviction_policy == EvictionPolicy.LRU: + return LRUEvictor() + else: + raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 711536bcc97be..689cbc2179ee1 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,6 +1,8 @@ import enum from abc import ABC, abstractmethod -from typing import Dict, List +from typing import List +from typing import Sequence as GenericSequence +from typing import Tuple from vllm.sequence import Sequence, SequenceGroup @@ -33,6 +35,11 @@ def get_block_space_manager_class(version: str): from vllm.core.block_manager_v2 import BlockSpaceManagerV2 return BlockSpaceManagerV2 + if version == "embedding": + from vllm.core.embedding_model_block_manager import ( + EmbeddingModelBlockSpaceManager) + return EmbeddingModelBlockSpaceManager + raise ValueError(f"Unknown version {version=}") @abstractmethod @@ -53,7 +60,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, - ) -> Dict[int, List[int]]: + ) -> List[Tuple[int, int]]: pass @abstractmethod @@ -62,12 +69,12 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: @abstractmethod def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: + num_lookahead_slots: int) -> AllocStatus: pass @abstractmethod def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> Dict[int, int]: + num_lookahead_slots: int) -> List[Tuple[int, int]]: pass @abstractmethod @@ -75,7 +82,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: pass @abstractmethod - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: pass @abstractmethod @@ -103,7 +110,8 @@ def access_all_blocks_in_seq( pass @abstractmethod - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0ae53f9374960..7c70b1b244f7d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,4 +1,6 @@ import enum +import os +import random import time from collections import deque from dataclasses import dataclass, field @@ -11,10 +13,16 @@ from vllm.lora.request import LoRARequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.utils import merge_dicts logger = init_logger(__name__) +# Test-only. If configured, decode is preempted with +# ARTIFICIAL_PREEMPTION_PROB% probability. +ENABLE_ARTIFICIAL_PREEMPT = bool( + os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa +ARTIFICIAL_PREEMPTION_PROB = 0.5 +ARTIFICIAL_PREEMPTION_MAX_CNT = 500 + class PreemptionMode(enum.Enum): """Preemption modes. @@ -42,8 +50,8 @@ class SchedulingBudget: """ token_budget: int max_num_seqs: int - _requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set) - _requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set) + _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set) + _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set) _num_batched_tokens: int = 0 _num_curr_seqs: int = 0 @@ -109,16 +117,19 @@ class SchedulerOutputs: num_prefill_groups: int # Total number of batched tokens. num_batched_tokens: int - # Blocks to swap in. Dict of CPU -> GPU block number. - blocks_to_swap_in: Dict[int, int] - # Blocks to swap out. Dict of GPU -> CPU block number. - blocks_to_swap_out: Dict[int, int] - # Blocks to copy. Source to a list of dest blocks. - blocks_to_copy: Dict[int, List[int]] + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] # Sequence groups that are going to be ignored. ignored_seq_groups: List[SequenceGroup] # The number of slots for lookahead decoding. num_lookahead_slots: int + # The number of requests in the running queue + running_queue_size: int + preempted: int def __post_init__(self): # Swap in and swap out should never happen at the same time. @@ -133,14 +144,18 @@ def is_empty(self) -> bool: return (not self.scheduled_seq_groups and not self.blocks_to_swap_in and not self.blocks_to_swap_out and not self.blocks_to_copy) - def _sort_by_lora_ids(self) -> bool: + def _sort_by_lora_ids(self): self.scheduled_seq_groups = sorted( self.scheduled_seq_groups, key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) @property def lora_requests(self) -> Set[LoRARequest]: - return {g.seq_group.lora_request for g in self.scheduled_seq_groups} + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } @dataclass @@ -160,9 +175,9 @@ class SchedulerRunningOutputs: # Sequences that are swapped out. swapped_out: List[SequenceGroup] # The blocks to swap out. - blocks_to_swap_out: Dict[int, int] + blocks_to_swap_out: List[Tuple[int, int]] # The blocks to copy. - blocks_to_copy: Dict[int, List[int]] + blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. num_lookahead_slots: int @@ -173,8 +188,8 @@ def create_empty(cls) -> "SchedulerRunningOutputs": prefill_seq_groups=[], preempted=[], swapped_out=[], - blocks_to_swap_out={}, - blocks_to_copy={}, + blocks_to_swap_out=[], + blocks_to_copy=[], num_lookahead_slots=0, ) @@ -192,20 +207,23 @@ class SchedulerSwappedInOutputs: # phase. I.e., it means the prefill has been chunked. prefill_seq_groups: List[SequenceGroup] # The blocks to swap in. - blocks_to_swap_in: Dict[int, int] + blocks_to_swap_in: List[Tuple[int, int]] # The blocks to copy. - blocks_to_copy: Dict[int, List[int]] + blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. num_lookahead_slots: int + # Infeasible sequence groups. + infeasible_seq_groups: List[SequenceGroup] @classmethod def create_empty(cls) -> "SchedulerSwappedInOutputs": return SchedulerSwappedInOutputs( decode_seq_groups=[], prefill_seq_groups=[], - blocks_to_swap_in={}, - blocks_to_copy={}, + blocks_to_swap_in=[], + blocks_to_copy=[], num_lookahead_slots=0, + infeasible_seq_groups=[], ) @@ -246,16 +264,14 @@ def __init__( # LoRAs. This should be improved in the future. self.lora_config = lora_config - if self.scheduler_config.chunked_prefill_enabled: - self.prompt_limit = self.scheduler_config.max_model_len - else: - self.prompt_limit = min( - self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) + version = "v1" + if self.scheduler_config.use_v2_block_manager: + version = "v2" + if self.scheduler_config.embedding_mode: + version = "embedding" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( - version="v2" if self.scheduler_config. - use_v2_block_manager else "v1") + version) # Create the block space manager. self.block_manager = BlockSpaceManagerImpl( @@ -282,6 +298,14 @@ def __init__( # Latency of the last prompt step self.last_prompt_latency = 0.0 + # The following field is test-only. It is used to inject artificial + # preemption. + self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT + self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT + if self.enable_artificial_preemption + else 0) + self.num_cumulative_preemption: int = 0 + @property def lora_enabled(self) -> bool: return bool(self.lora_config) @@ -293,7 +317,6 @@ def num_decoding_tokens_per_seq(self) -> int: def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. - logger.debug(f"add_seq_group {seq_group.request_id}") self.waiting.append(seq_group) def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: @@ -317,7 +340,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: for seq_group in state_queue: if not request_ids: # Using 'break' here may add two extra iterations, - # but is acceptable to reduce complexity . + # but is acceptable to reduce complexity. break if seq_group.request_id in request_ids: # Appending aborted group into pending list. @@ -333,7 +356,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: self.free_seq(seq) def has_unfinished_seqs(self) -> bool: - return self.waiting or self.running or self.swapped + return len(self.waiting) != 0 or len(self.running) != 0 or len( + self.swapped) != 0 def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) @@ -368,8 +392,8 @@ def _schedule_running( scheduling and SchedulerRunningOutputs. """ # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_out: Dict[int, int] = {} - blocks_to_copy: Dict[int, List[int]] = {} + blocks_to_swap_out: List[Tuple[int, int]] = [] + blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] @@ -382,25 +406,23 @@ def _schedule_running( # groups to preempt. now = time.time() running_queue = policy.sort_by_priority(now, running_queue) - while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( seq_group, SequenceStatus.RUNNING, enable_chunking, budget) - # We can have up to 1 running prefill at any given time in running - # queue, which means we can guarantee chunk size is at least 1. - assert num_running_tokens != 0 - num_running_seqs = seq_group.get_max_num_running_seqs() + if num_running_tokens == 0: + break running_queue.popleft() while not self._can_append_slots(seq_group): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) + num_running_seqs = seq_group.get_max_num_running_seqs() budget.subtract_num_seqs(seq_group.request_id, num_running_seqs) if curr_loras is not None and seq_group.lora_int_id > 0: - curr_loras.pop(seq_group.lora_int_id) + curr_loras.remove(seq_group.lora_int_id) if running_queue: # Preempt the lowest-priority sequence groups. @@ -422,7 +444,6 @@ def _schedule_running( swapped_out.append(seq_group) break else: - logger.debug(f"append slot for {seq_group}") self._append_slots(seq_group, blocks_to_copy) is_prefill = seq_group.is_prefill() if is_prefill: @@ -436,13 +457,16 @@ def _schedule_running( token_chunk_size=1)) budget.add_num_batched_tokens(seq_group.request_id, num_running_tokens) - budget.add_num_seqs(seq_group.request_id, num_running_seqs) + # OPTIMIZATION: Note that get_max_num_running_seqs is + # expensive. For the default scheduling chase where + # enable_chunking is False, num_seqs are updated before running + # this method, so we don't have to update it again here. + if enable_chunking: + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.add_num_seqs(seq_group.request_id, num_running_seqs) if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) - # Make sure all queues are updated. - assert len(running_queue) == 0 - return running_queue, SchedulerRunningOutputs( decode_seq_groups=decode_seq_groups, prefill_seq_groups=prefill_seq_groups, @@ -485,25 +509,39 @@ def _schedule_swapped( SchedulerSwappedInOutputs. """ # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_in: Dict[int, int] = {} - blocks_to_copy: Dict[int, List[int]] = {} + blocks_to_swap_in: List[Tuple[int, int]] = [] + blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] now = time.time() swapped_queue = policy.sort_by_priority(now, swapped_queue) + infeasible_seq_groups: List[SequenceGroup] = [] - leftover_swapped = deque() + leftover_swapped: Deque[SequenceGroup] = deque() while swapped_queue: seq_group = swapped_queue[0] # If the sequence group cannot be swapped in, stop. - if not self.block_manager.can_swap_in(seq_group): + alloc_status = self.block_manager.can_swap_in(seq_group) + if alloc_status == AllocStatus.LATER: break + elif alloc_status == AllocStatus.NEVER: + logger.warning( + "Failing the request %s because there's not enough kv " + "cache blocks to run the entire sequence.", + seq_group.request_id) + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + infeasible_seq_groups.append(seq_group) + swapped_queue.popleft() + continue lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id - if (lora_int_id > 0 and lora_int_id not in curr_loras + assert curr_loras is not None + assert self.lora_config is not None + if (lora_int_id > 0 and (lora_int_id not in curr_loras) and len(curr_loras) >= self.lora_config.max_loras): # We don't have a space for another LoRA, so # we ignore this request for now. @@ -534,7 +572,6 @@ def _schedule_swapped( ScheduledSequenceGroup(seq_group, token_chunk_size=num_new_tokens)) else: - assert num_new_tokens == 1 decode_seq_groups.append( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) @@ -548,7 +585,24 @@ def _schedule_swapped( blocks_to_swap_in=blocks_to_swap_in, blocks_to_copy=blocks_to_copy, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False)) + is_prefill=False), + infeasible_seq_groups=infeasible_seq_groups, + ) + + def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: + if self.scheduler_config.chunked_prefill_enabled: + prompt_limit = self.scheduler_config.max_model_len + else: + prompt_limit = min(self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens) + + # Model is fine tuned with long context. Return the fine tuned max_len. + if (seq_group.lora_request + and seq_group.lora_request.long_lora_max_len): + assert prompt_limit <= seq_group.lora_request.long_lora_max_len + return seq_group.lora_request.long_lora_max_len + else: + return prompt_limit def _schedule_prefills( self, @@ -589,7 +643,7 @@ def _schedule_prefills( # Copy the queue so that the input queue is not modified. waiting_queue = deque([s for s in waiting_queue]) - leftover_waiting_sequences = deque() + leftover_waiting_sequences: Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and waiting_queue: seq_group = waiting_queue[0] @@ -604,10 +658,11 @@ def _schedule_prefills( num_prompt_tokens = waiting_seqs[0].get_len() assert num_new_tokens == num_prompt_tokens - if num_new_tokens > self.prompt_limit: + prompt_limit = self._get_prompt_limit(seq_group) + if num_new_tokens > prompt_limit: logger.warning( - f"Input prompt ({num_new_tokens} tokens) is too long" - f" and exceeds limit of {self.prompt_limit}") + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", num_new_tokens, prompt_limit) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -620,8 +675,9 @@ def _schedule_prefills( break elif can_allocate == AllocStatus.NEVER: logger.warning( - f"Input prompt ({num_new_tokens} tokens) is too long" - f" and exceeds the capacity of block_manager") + "Input prompt (%d tokens) is too long" + " and exceeds the capacity of block_manager", + num_new_tokens) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -631,6 +687,8 @@ def _schedule_prefills( lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None if (self.lora_enabled and lora_int_id > 0 and lora_int_id not in curr_loras and len(curr_loras) >= self.lora_config.max_loras): @@ -650,7 +708,7 @@ def _schedule_prefills( if curr_loras is not None and lora_int_id > 0: curr_loras.add(lora_int_id) waiting_queue.popleft() - self._allocate_and_set_running(seq_group, num_new_tokens) + self._allocate_and_set_running(seq_group) seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) @@ -670,7 +728,7 @@ def _schedule_prefills( def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - The current policy is designed to opimimize the throughput. First, + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. If there's a pressure on GPU memory, decode requests can be swapped or preempted. @@ -686,8 +744,8 @@ def _schedule_default(self) -> SchedulerOutputs: budget.add_num_seqs(seq_group.request_id, seq_group.get_max_num_running_seqs()) curr_loras = set( - seq_group.lora_int_id - for seq_group in self.running) if self.lora_enabled else None + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None remaining_waiting, prefills = (self.waiting, SchedulerPrefillOutputs.create_empty()) @@ -737,6 +795,8 @@ def _schedule_default(self) -> SchedulerOutputs: # Update swapped requests. self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) + preempted = (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)) # There should be no prefill from running queue because this policy # doesn't allow chunked prefills. @@ -750,12 +810,13 @@ def _schedule_default(self) -> SchedulerOutputs: num_batched_tokens=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, - swapped_in.blocks_to_copy), - ignored_seq_groups=prefills.ignored_seq_groups, - num_lookahead_slots=(prefills.num_lookahead_slots + - running_scheduled.num_lookahead_slots + - swapped_in.num_lookahead_slots), + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, + num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), + preempted=preempted, ) def _schedule_chunked_prefill(self): @@ -776,7 +837,7 @@ def _schedule_chunked_prefill(self): token_budget=self.scheduler_config.max_num_batched_tokens, max_num_seqs=self.scheduler_config.max_num_seqs, ) - curr_loras = set() + curr_loras: Set[int] = set() remaining_waiting, prefills = (self.waiting, SchedulerPrefillOutputs.create_empty()) @@ -826,25 +887,25 @@ def _schedule_chunked_prefill(self): # Update swapped requests. self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) - return SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.decode_seq_groups + running_scheduled.prefill_seq_groups + - swapped_in.decode_seq_groups + - swapped_in.prefill_seq_groups), + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), num_prefill_groups=(len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)), num_batched_tokens=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, - swapped_in.blocks_to_copy), + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups, - num_lookahead_slots=(prefills.num_lookahead_slots + - running_scheduled.num_lookahead_slots + - swapped_in.num_lookahead_slots), + num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), ) def _schedule(self) -> SchedulerOutputs: @@ -858,6 +919,13 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ + # It is True only for testing case to trigger artificial preemption. + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): + self.artificial_preempt_cnt -= 1 + return False + # Appending slots only occurs in decoding. is_prefill = False @@ -866,15 +934,6 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), ) - def _can_swap_in(self, seq_group: SequenceGroup) -> bool: - # Swapping in is considered decode. - is_prefill = False - - return self.block_manager.can_swap_in( - seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), - ) - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Schedule sequence groups. # This function call changes the internal states of the scheduler @@ -905,15 +964,31 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: self.block_manager.get_common_computed_block_ids( seq_group.get_seqs(status=SequenceStatus.RUNNING))) + do_sample = True + if seq_group.is_prefill(): + seqs = seq_group.get_seqs() + # Prefill has only 1 sequence. + assert len(seqs) == 1 + # In the next iteration, all prompt tokens are not computed. + # It means the prefill is chunked, and we don't need sampling. + # NOTE: We use get_len instead of get_prompt_len because when + # a sequence is preempted, prefill includes previous generated + # output tokens. + if (token_chunk_size + seqs[0].data.get_num_computed_tokens() < + seqs[0].data.get_len()): + do_sample = False + # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. - is_prompt = i < scheduler_outputs.num_prefill_groups + is_prompt = seq_group.is_prefill() seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=is_prompt, seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + do_sample=do_sample, + pooling_params=seq_group.pooling_params, token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, @@ -948,8 +1023,7 @@ def free_finished_seq_groups(self) -> None: self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) - def _allocate_and_set_running(self, seq_group: SequenceGroup, - num_new_tokens: int) -> None: + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING @@ -957,32 +1031,29 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup, def _append_slots( self, seq_group: SequenceGroup, - blocks_to_copy: Dict[int, List[int]], + blocks_to_copy: List[Tuple[int, int]], ) -> None: """Appends new slots to the sequences in the given sequence group. Args: seq_group (SequenceGroup): The sequence group containing the sequences to append slots to. - blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source - block indices to lists of destination block indices. This - dictionary is updated with the new source and destination block - indices for the appended slots. + blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two + ints, the first int is the source block index, and the second + int is the destination block index. This list is updated with + the new source and destination block indices for the appended + slots. """ num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): cows = self.block_manager.append_slots(seq, num_lookahead_slots) - - for src, dests in cows.items(): - if src not in blocks_to_copy: - blocks_to_copy[src] = [] - blocks_to_copy[src].extend(dests) + blocks_to_copy.extend(cows) def _preempt( self, seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_out: List[Tuple[int, int]], preemption_mode: Optional[PreemptionMode] = None, ) -> PreemptionMode: # If preemption mode is not specified, we determine the mode as follows: @@ -1001,6 +1072,17 @@ def _preempt( preemption_mode = PreemptionMode.RECOMPUTE else: preemption_mode = PreemptionMode.SWAP + + if self.num_cumulative_preemption % 50 == 0: + logger.warning( + "Sequence group %s is preempted by %s mode because there is " + "not enough KV cache space. This can affect the end-to-end " + "performance. Increase gpu_memory_utilization or " + "tensor_parallel_size to provide more KV cache memory. " + "total_num_cumulative_preemption=%d", seq_group.request_id, + preemption_mode, self.num_cumulative_preemption + 1) + self.num_cumulative_preemption += 1 + if preemption_mode == PreemptionMode.RECOMPUTE: self._preempt_by_recompute(seq_group) elif preemption_mode == PreemptionMode.SWAP: @@ -1023,24 +1105,24 @@ def _preempt_by_recompute( def _preempt_by_swap( self, seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_out: List[Tuple[int, int]], ) -> None: self._swap_out(seq_group, blocks_to_swap_out) def _swap_in( self, seq_group: SequenceGroup, - blocks_to_swap_in: Dict[int, int], + blocks_to_swap_in: List[Tuple[int, int]], ) -> None: mapping = self.block_manager.swap_in(seq_group) - blocks_to_swap_in.update(mapping) + blocks_to_swap_in.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): seq.status = SequenceStatus.RUNNING def _swap_out( self, seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_out: List[Tuple[int, int]], ) -> None: if not self.block_manager.can_swap_out(seq_group): # FIXME(woosuk): Abort the sequence group instead of aborting the @@ -1049,7 +1131,7 @@ def _swap_out( "Aborted due to the lack of CPU swap space. Please increase " "the swap space to avoid this error.") mapping = self.block_manager.swap_out(seq_group) - blocks_to_swap_out.update(mapping) + blocks_to_swap_out.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED @@ -1084,7 +1166,7 @@ def _get_num_lookahead_slots(self, is_prefill: bool) -> int: def _get_num_new_tokens(self, seq_group: SequenceGroup, status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> Tuple[int, bool]: + budget: SchedulingBudget) -> int: """Get the next new tokens to compute for a given sequence group that's in a given `status`. @@ -1092,11 +1174,14 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, if `enable_chunking` is True. If a sequence group has multiple sequences (e.g., running beam search), it means it is in decoding phase, so chunking doesn't happen. + + Returns 0 if the new token cannot be computed due to token budget. """ num_new_tokens = 0 seqs = seq_group.get_seqs(status=status) for seq in seqs: num_new_tokens += seq.get_num_new_tokens() + assert num_new_tokens > 0 # Chunk if a running request cannot fit in. # If number of seq > 1, it means it is doing beam search in a # decode phase. Do not chunk in that case. diff --git a/vllm/distributed/__init__.py b/vllm/distributed/__init__.py new file mode 100644 index 0000000000000..db325cfabf55e --- /dev/null +++ b/vllm/distributed/__init__.py @@ -0,0 +1,3 @@ +from .communication_op import * +from .parallel_state import * +from .utils import * diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py new file mode 100644 index 0000000000000..b04adb532dd38 --- /dev/null +++ b/vllm/distributed/communication_op.py @@ -0,0 +1,331 @@ +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch.distributed import ProcessGroup + +from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_ca_communicator, + get_tp_pynccl_communicator) +from vllm.utils import is_hip + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +@contextmanager +def graph_capture(): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the CUDA graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + ca_comm = get_tp_ca_communicator() + maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() + with torch.cuda.stream(stream), maybe_ca_context: + # In graph mode, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the tensor + # size is too large, it will fallback to the next available option. + # + # The current status on ROCm is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | disabled| disabled| + # PyNccl | disabled| disabled| + # torch.distributed | enabled | enabled | + # + # In summary: + # On ROCm, we use PyTorch NCCL at all times. + # On CUDA, when using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using CUDA + # graph, we use either custom all-reduce kernel or PyTorch NCCL. + # We always prioritize using custom all-reduce kernel but fall back + # to PyTorch or pynccl if it is disabled or not supported. + tp_pynccl_comm = get_tp_pynccl_communicator() + pp_pynccl_comm = get_pp_pynccl_communicator() + if not tp_pynccl_comm: + maybe_tp_pynccl_context = nullcontext() + else: + maybe_tp_pynccl_context = tp_pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + if not pp_pynccl_comm: + maybe_pp_pynccl_context = nullcontext() + else: + maybe_pp_pynccl_context = pp_pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + with maybe_tp_pynccl_context, maybe_pp_pynccl_context: + yield graph_capture_context + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group. + + NOTE: This operation will be applied in-place on the input tensor if + disable_custom_all_reduce is set to True. Otherwise, this operation may or + may not be applied in place depending on whether custom all reduce is + invoked for a particular tensor, which further depends on the tensor size + and GPU topology. + + TLDR: always assume this function modifies its input, but use the return + value as the output. + """ + ca_comm = get_tp_ca_communicator() + + # Bypass the function if we are using only 1 GPU. + if get_tensor_model_parallel_world_size() == 1: + return input_ + if ca_comm is not None: + out = ca_comm.custom_all_reduce(input_) + if out is not None: + return out + pynccl_comm = get_tp_pynccl_communicator() + if (pynccl_comm is not None and not pynccl_comm.disabled): + pynccl_comm.all_reduce(input_) + else: + torch.distributed.all_reduce(input_, + group=get_tensor_model_parallel_group()) + return input_ + + +def tensor_model_parallel_all_gather(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=get_tensor_model_parallel_group()) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + +def tensor_model_parallel_gather(input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> torch.Tensor: + """Gather the input tensor across model parallel group. + + NOTE: We assume that the input tensor is on the same device across + all the ranks. + """ + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if get_tensor_model_parallel_rank() == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, + gather_list, + dst=dst, + group=get_tensor_model_parallel_group()) + if get_tensor_model_parallel_rank() == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + +def broadcast(input_: torch.Tensor, + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input tensor.""" + group = group or torch.distributed.group.WORLD + ranks = torch.distributed.get_process_group_ranks(group) + assert src in ranks, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + world_size = torch.distributed.get_world_size(group=group) + if world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast(input_, src=src, group=group) + return input_ + + +def broadcast_object_list(obj_list: List[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input object list.""" + group = group or torch.distributed.group.WORLD + ranks = torch.distributed.get_process_group_ranks(group) + assert src in ranks, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + world_size = torch.distributed.get_world_size(group=group) + if world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, src=src, group=group) + return obj_list + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[Any, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list = [] + tensor_list = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = "cpu" if value.is_cpu else "cuda" + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +def broadcast_tensor_dict( + tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None +) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + `group` is used to broadcast the tensors, while `metadata_group` is used + to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, + dtypes). + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() + or torch.distributed.get_world_size(group=group) == 1): + return tensor_dict + + group = group or torch.distributed.group.WORLD + if is_hip(): + metadata_group = metadata_group or torch.distributed.group.WORLD + else: + metadata_group = metadata_group or get_cpu_world_group() + ranks = torch.distributed.get_process_group_ranks(group) + assert src in ranks, f"Invalid src rank ({src})" + + rank = torch.distributed.get_rank() + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` involves serialization and deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + torch.distributed.broadcast_object_list([metadata_list], + src=src, + group=metadata_group) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + recv_metadata_list = [None] + torch.distributed.broadcast_object_list(recv_metadata_list, + src=src, + group=metadata_group) + assert recv_metadata_list[0] is not None + tensor_dict = {} + async_handles = [] + for key, value in recv_metadata_list[0]: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict diff --git a/vllm/distributed/device_communicators/__init__.py b/vllm/distributed/device_communicators/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py new file mode 100644 index 0000000000000..a3902aecb3793 --- /dev/null +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -0,0 +1,311 @@ +from contextlib import contextmanager +from typing import Any, List, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm.distributed.device_communicators.custom_all_reduce_utils import ( + gpu_p2p_access_check) +from vllm.distributed.parallel_state import ( + get_local_rank, get_tensor_model_parallel_cpu_group) +from vllm.logger import init_logger + +try: + import pynvml + + from vllm._C import custom_ar + + @contextmanager + def _nvml(): + try: + pynvml.nvmlInit() + yield + finally: + pynvml.nvmlShutdown() + +except ImportError: + # For AMD GPUs + custom_ar = None + pynvml = None + + @contextmanager + def _nvml(): + try: + yield + finally: + pass + + +logger = init_logger(__name__) + + +@_nvml() +def _is_full_nvlink(device_ids: List[int]) -> bool: + """ + query if the set of gpus are fully connected by nvlink (1 hop) + Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`, + so it works on real physical device ids. + """ + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError as error: + logger.error( + "NVLink detection failed. This is normal if your" + " machine has no NVLink equipped.", + exc_info=error) + return False + return True + + +def _can_p2p(rank: int, world_size: int) -> bool: + for i in range(world_size): + if i == rank: + continue + if not gpu_p2p_access_check(rank, i): + return False + return True + + +class CustomAllreduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + # max_size: max supported allreduce size + def __init__(self, + group: Optional[ProcessGroup] = None, + device: Optional[Union[int, str, torch.device]] = None, + max_size=8192 * 1024) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if custom_ar is None: + # disable because of missing custom allreduce library + # e.g. in a non-cuda environment + return + + group = group or get_tensor_model_parallel_cpu_group() + self.group = group + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "CustomAllreduce should be attached to a non-NCCL group.") + + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) + return + + if device is None: + local_rank = get_local_rank() + device = torch.device(f"cuda:{local_rank}") + elif isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(torch.cuda.device_count())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported + # this checks hardware and driver support for NVLink + full_nvlink = _is_full_nvlink(physical_device_ids) + if world_size > 2 and not full_nvlink: + logger.warning( + "Custom allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. To silence this warning, " + "specify disable_custom_all_reduce=True explicitly.") + return + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + if not _can_p2p(rank, world_size): + logger.warning( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.") + return + + self.disabled = False + # buffers memory are owned by this Python class and passed to C++ + # meta data composes of two parts: meta data for synchronization + # (256 bytes) and a temporary buffer for storing intermediate + # allreduce results. + self.meta = torch.zeros(custom_ar.meta_size() + max_size, + dtype=torch.uint8, + device=self.device) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer = torch.empty(max_size, + dtype=torch.uint8, + device=self.device) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty(8 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + self.max_size = max_size + self.rank = rank + self.world_size = world_size + handles, offsets = self._get_ipc_meta(self.meta) + self.full_nvlink = full_nvlink + self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, + handles, offsets, rank, + self.full_nvlink) + self.register_buffer(self.buffer) + + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + + def _get_ipc_meta(self, inp: torch.Tensor): + data = inp.untyped_storage()._share_cuda_() + shard_data = ( + data[1], # ipc handle to base ptr + data[3], # offset of base ptr + ) + return self._gather_ipc_meta(shard_data) + + def _gather_ipc_meta(self, shard_data): + # Note: don't use `[[None]] * self.world_size` here + # because it will create a list of the same reference + all_data: List[Optional[Any]] = [[None] + for i in range(self.world_size)] + all_data[self.rank][0] = shard_data + + ranks = dist.get_process_group_ranks(group=self.group) + ranks.sort() + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], + src=rank, + group=self.group, + device="cpu") + + # we cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. + + handles = [] + offsets = [] + for i in range(len(all_data)): + handles.append(all_data[i][0][0]) # type: ignore + offsets.append(all_data[i][0][1]) # type: ignore + return handles, offsets + + def register_buffer(self, inp: torch.Tensor): + handles, offsets = self._get_ipc_meta(inp) + custom_ar.register_buffer(self._ptr, inp, handles, offsets) + + def register_graph_buffers(self): + handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr) + handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) + logger.info("Registering %d cuda graph addresses", len(offset)) + custom_ar.register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + return custom_ar.should_custom_ar(inp, self.max_size, self.world_size, + self.full_nvlink) + + # all reduce, assuming inp tensor is IPC registered with register_buffer, + # or, in the context of cuda graphs, register_graph_buffers + def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + custom_ar.all_reduce_reg(self._ptr, inp, out) + return out + + # all reduce, assuming inp tensor is NOT IPC registered + def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) + return out + + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + # when custom allreduce is disabled, this will be None + if self.disabled: + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + if self.should_custom_ar(input): + return self.all_reduce_reg(input) + else: + if self.should_custom_ar(input): + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place + return torch.empty_like(input) + else: + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + if self.should_custom_ar(input): + return self.all_reduce_unreg(input) + + return None + + def close(self): + if not self.disabled and self._ptr: + custom_ar.dispose(self._ptr) + self._ptr = 0 + + def __del__(self): + self.close() diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py new file mode 100644 index 0000000000000..24ef3cb45b19d --- /dev/null +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -0,0 +1,186 @@ +import json +import os +import sys +import tempfile +import time +from contextlib import contextmanager +from typing import Callable, Dict, List, Optional + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@contextmanager +def mute_output(): + with open(os.devnull, "w") as f: + sys.stderr = f + sys.stdout = f + yield + + +def producer(i: int, + init_method: str, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + with mute_output(): + dist.init_process_group( + backend="gloo", + init_method=init_method, + world_size=2, + rank=0, + ) + # produce a tensor in GPU i + data = torch.zeros((128, ), device=f"cuda:{i}") + # get the information to reconstruct the shared tensor + func, args = torch.multiprocessing.reductions.reduce_tensor(data) + args = list(args) + dist.broadcast_object_list([(func, args)], src=0) + dist.barrier() + torch.cuda.synchronize() + assert torch.all(data == 1).item() + + +def consumer(j: int, + init_method: str, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + with mute_output(): + dist.init_process_group( + backend="gloo", + init_method=init_method, + world_size=2, + rank=1, + ) + torch.cuda.set_device(j) + recv = [None] + dist.broadcast_object_list(recv, src=0) + func: Callable + args: List + func, args = recv[0] # type: ignore + # `args[6]` is the device id + # by default pytorch will use `i` from the producer + # here we need to set it to `j` to test P2P access + args[6] = j + data = func(*args) + data += 1 + dist.barrier() + torch.cuda.synchronize() + assert torch.all(data == 1).item() + + +def can_actually_p2p(i, j): + """ + Usually, checking if P2P access is enabled can be done by + `torch.cuda.can_device_access_peer(i, j)`. However, sometimes + the driver might be broken, and `torch.cuda.can_device_access_peer(i, j)` + returns `True` even if P2P access is not actually possible. + See https://github.com/vllm-project/vllm/issues/2728 and + https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 + Therefore, we have to perform a real P2P access to check if it is actually + possible. + + Note on p2p and cuda IPC: + Usually, one process uses one GPU: + GPU i --> cuda context i --> tensor i --> process i + + We need to combine p2p and cuda IPC, so that: + GPU i --> cuda context i --> tensor i --> process i + |shared| + GPU j --> cuda context j --> tensor j --> process j + That is to say, process i creates a tensor in GPU i, passes IPC handle to + process j, and process j accesses the tensor in GPU j. Any operation on the + tensor in process j will be reflected in the tensor in process i, because + they are the same memory segment. + It is important to note that process j accesses the tensor in GPU j, not + GPU i. That's why we need p2p access. # noqa + """ + cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None) + # pass the CUDA_VISIBLE_DEVICES to the child process + # to make sure they see the same set of GPUs + + # make sure the temp file is not the same across different calls + temp_path = tempfile.mktemp() + str(time.time()) + # create an empty file + with open(temp_path, "w"): + pass + init_method = f"file://{temp_path}" + + # make sure the processes are spawned + smp = mp.get_context("spawn") + pi = smp.Process(target=producer, + args=(i, init_method, cuda_visible_devices)) + pj = smp.Process(target=consumer, + args=(j, init_method, cuda_visible_devices)) + pi.start() + pj.start() + pi.join() + pj.join() + return pi.exitcode == 0 and pj.exitcode == 0 + + +# why do we need this cache? +# we are testing peer-to-peer (p2p) access between GPUs,across processes. +# if we test it every time, it will be very slow, because we need to create +# N * N * 2 processes, where N is the world size. This is very slow. +# to reduce the time, we use a cache file to store the p2p access status. +# the cache file is generated by the master process if it does not exist. +# then all the processes can read the cache file to check the p2p access status. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None + + +def gpu_p2p_access_check(i: int, j: int) -> bool: + """Check if GPU i can access GPU j.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{i}->{j}"] + + is_distributed = dist.is_initialized() + + num_dev = torch.cuda.device_count() + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT + path = os.path.expanduser( + f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json" + ) + os.makedirs(os.path.dirname(path), exist_ok=True) + if ((not is_distributed or get_local_rank() == 0) + and (not os.path.exists(path))): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info("generating GPU P2P access cache for in %s", path) + cache = {} + for _i in range(num_dev): + for _j in range(num_dev): + cache[f"{_i}->{_j}"] = can_actually_p2p(_i, _j) + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + cpu_world_group = get_cpu_world_group() + dist.barrier(cpu_world_group) + logger.info("reading GPU P2P access cache from %s", path) + with open(path, "r") as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{i}->{j}"] + + +__all__ = ["gpu_p2p_access_check"] diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py new file mode 100644 index 0000000000000..f5f1de0c71615 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl.py @@ -0,0 +1,185 @@ +from contextlib import contextmanager +from typing import Optional, Union + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) +from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class PyNcclCommunicator: + + def __init__( + self, + group: Optional[ProcessGroup] = None, + device: Optional[Union[int, str, torch.device]] = None, + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + assert dist.is_initialized() + group = get_cpu_world_group() if group is None else group + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "PyNcclCommunicator should be attached to a non-NCCL group.") + self.group = group + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + self.stream = None + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + self.stream = None + return + + self.available = True + self.disabled = False + + logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + if device is None: + local_rank = get_local_rank() + device = torch.device(f"cuda:{local_rank}") + elif isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) + self.stream = torch.cuda.Stream() + + # A small all_reduce for warmup. + data = torch.zeros(1, device=device) + self.all_reduce(data) + self.stream.synchronize() + del data + + # by default it is disabled, e.g. in profiling models and prefill phase. + # to use it, use under `with obj.change_state(enable=True)`, usually + # when we are using CUDA graph. + self.disabled = True + + def all_reduce(self, + tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()), + buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + + def send(self, + tensor: torch.Tensor, + dst: Optional[int] = None, + stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + if dst is None: + dst = (self.rank + 1) % self.world_size + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + self.comm, cudaStream_t(stream.cuda_stream)) + + def recv(self, + tensor: torch.Tensor, + src: Optional[int] = None, + stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + if src is None: + src = (self.rank - 1) % self.world_size + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + + @contextmanager + def change_state(self, + enable: Optional[bool] = None, + stream: Optional[torch.cuda.Stream] = None): + """ + A context manager to change the state of the communicator. + """ + if enable is None: + # guess a default value when not specified + enable = self.available + + if stream is None: + stream = self.stream + + old_disable = self.disabled + old_stream = self.stream + + self.stream = stream + self.disabled = not enable + yield + + self.disabled = old_disable + self.stream = old_stream diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 0000000000000..50d6719fbfe62 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,278 @@ +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s ." + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s." + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/distributed/parallel_state.py similarity index 62% rename from vllm/model_executor/parallel_utils/parallel_state.py rename to vllm/distributed/parallel_state.py index 3bbfa1bd5443a..6d1b3b47c3dd7 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -3,17 +3,28 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" -import contextlib -from typing import Optional +from typing import List, Optional import torch +from torch.distributed import ProcessGroup -from vllm.model_executor.parallel_utils import pynccl_utils +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.utils import is_hip + +logger = init_logger(__name__) + +_ENABLE_CUSTOM_ALL_REDUCE = True # Tensor model parallel group that the current rank belongs to. -_TENSOR_MODEL_PARALLEL_GROUP = None +_TP_DEVICE_GROUP: Optional[ProcessGroup] = None +_TP_CPU_GROUP: Optional[ProcessGroup] = None +_TP_PYNCCL_COMMUNICATOR = None +_TP_CA_COMMUNICATOR = None # Pipeline model parallel group that the current rank belongs to. -_PIPELINE_MODEL_PARALLEL_GROUP = None +_PP_DEVICE_GROUP: Optional[ProcessGroup] = None +_PP_CPU_GROUP: Optional[ProcessGroup] = None +_PP_PYNCCL_COMMUNICATOR = None # when people blindly call `torch.distributed.all_reduce` etc, # it will use this group. It is initialized with the `backend` @@ -37,16 +48,47 @@ # A list of global ranks for each pipeline group to ease calculation of the # source rank when broadcasting from the first or last pipeline stage. -_PIPELINE_GLOBAL_RANKS = None +_PP_GLOBAL_RANKS: Optional[List[int]] = None + +_LOCAL_RANK = -1 + + +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + +def get_pp_pynccl_communicator(): + global _PP_PYNCCL_COMMUNICATOR + return _PP_PYNCCL_COMMUNICATOR + + +def get_tp_pynccl_communicator(): + global _TP_PYNCCL_COMMUNICATOR + return _TP_PYNCCL_COMMUNICATOR + + +def get_tp_ca_communicator(): + global _TP_CA_COMMUNICATOR + return _TP_CA_COMMUNICATOR + + +def get_local_rank(): + global _LOCAL_RANK + return _LOCAL_RANK def init_distributed_environment( - world_size: int, - rank: int, - distributed_init_method: Optional[str] = None, + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", local_rank: int = -1, backend: str = "nccl", ): + logger.debug( + "world_size=%d rank=%d local_rank=%d " + "distributed_init_method=%s backend=%s", world_size, rank, local_rank, + distributed_init_method, backend) if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " @@ -62,6 +104,26 @@ def init_distributed_environment( ranks = list(range(torch.distributed.get_world_size())) _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, backend="gloo") + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _LOCAL_RANK + _LOCAL_RANK = local_rank + # A small all_reduce for warmup. + data = torch.zeros(1) + if torch.cuda.is_available(): + data = data.to(device=f"cuda:{local_rank}") + torch.distributed.all_reduce(data) + if torch.cuda.is_available(): + torch.cuda.synchronize() + del data def initialize_model_parallel( @@ -111,27 +173,56 @@ def initialize_model_parallel( rank = torch.distributed.get_rank() # Build the tensor model-parallel groups. - global _TENSOR_MODEL_PARALLEL_GROUP - assert _TENSOR_MODEL_PARALLEL_GROUP is None, ( + global _TP_DEVICE_GROUP, _TP_CPU_GROUP + global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR + assert _TP_DEVICE_GROUP is None, ( "tensor model parallel group is already initialized") for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, - (i + 1) * tensor_model_parallel_size) + ranks = list( + range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size)) group = torch.distributed.new_group(ranks, backend=backend) + cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: - _TENSOR_MODEL_PARALLEL_GROUP = group + _TP_DEVICE_GROUP = group + _TP_CPU_GROUP = cpu_group + + if tensor_model_parallel_size > 1 and not is_hip(): + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( + group=_TP_CPU_GROUP, + device=_LOCAL_RANK, + ) + + # Initialize a custom fast all-reduce implementation. + if _ENABLE_CUSTOM_ALL_REDUCE: + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + _TP_CA_COMMUNICATOR = CustomAllreduce( + group=_TP_CPU_GROUP, + device=_LOCAL_RANK, + ) # Build the pipeline model-parallel groups. - global _PIPELINE_MODEL_PARALLEL_GROUP - global _PIPELINE_GLOBAL_RANKS - assert _PIPELINE_MODEL_PARALLEL_GROUP is None, ( + global _PP_DEVICE_GROUP, _PP_CPU_GROUP + global _PP_PYNCCL_COMMUNICATOR + global _PP_GLOBAL_RANKS + assert _PP_DEVICE_GROUP is None, ( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group = torch.distributed.new_group(ranks, backend=backend) + cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: - _PIPELINE_MODEL_PARALLEL_GROUP = group - _PIPELINE_GLOBAL_RANKS = ranks + _PP_DEVICE_GROUP = group + _PP_CPU_GROUP = cpu_group + _PP_GLOBAL_RANKS = ranks + + if pipeline_model_parallel_size > 1 and not is_hip(): + _PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( + group=_PP_CPU_GROUP, + device=_LOCAL_RANK, + ) def ensure_model_parallel_initialized( @@ -164,8 +255,7 @@ def ensure_model_parallel_initialized( def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (_TENSOR_MODEL_PARALLEL_GROUP is not None - and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None) def get_cpu_world_group(): @@ -176,16 +266,30 @@ def get_cpu_world_group(): def get_tensor_model_parallel_group(): """Get the tensor model parallel group the caller rank belongs to.""" - assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ( + assert _TP_DEVICE_GROUP is not None, ( "tensor model parallel group is not initialized") - return _TENSOR_MODEL_PARALLEL_GROUP + return _TP_DEVICE_GROUP + + +def get_tensor_model_parallel_cpu_group(): + """Get the tensor model parallel cpu group the caller rank belongs to.""" + assert _TP_CPU_GROUP is not None, ( + "tensor model parallel cpu group is not initialized") + return _TP_CPU_GROUP def get_pipeline_model_parallel_group(): """Get the pipeline model parallel group the caller rank belongs to.""" - assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, ( + assert _PP_DEVICE_GROUP is not None, ( "pipeline model parallel group is not initialized") - return _PIPELINE_MODEL_PARALLEL_GROUP + return _PP_DEVICE_GROUP + + +def get_pipeline_model_parallel_cpu_group(): + """Get the pipeline model parallel cpu group the caller rank belongs to.""" + assert _PP_CPU_GROUP is not None, ( + "pipeline model parallel cpu group is not initialized") + return _PP_CPU_GROUP def get_tensor_model_parallel_world_size(): @@ -222,81 +326,54 @@ def get_tensor_model_parallel_src_rank(): def get_pipeline_model_parallel_first_rank(): """Return the global rank of the first process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") - return _PIPELINE_GLOBAL_RANKS[0] + return _PP_GLOBAL_RANKS[0] def get_pipeline_model_parallel_last_rank(): """Return the global rank of the last process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") last_rank_local = get_pipeline_model_parallel_world_size() - 1 - return _PIPELINE_GLOBAL_RANKS[last_rank_local] + return _PP_GLOBAL_RANKS[last_rank_local] def get_pipeline_model_parallel_next_rank(): """Return the global rank that follows the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] def get_pipeline_model_parallel_prev_rank(): """Return the global rank that precedes the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] def destroy_model_parallel(): """Set the groups to none and destroy them.""" - global _TENSOR_MODEL_PARALLEL_GROUP - if _TENSOR_MODEL_PARALLEL_GROUP: - torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP) - _TENSOR_MODEL_PARALLEL_GROUP = None - global _PIPELINE_MODEL_PARALLEL_GROUP - if _PIPELINE_MODEL_PARALLEL_GROUP: - torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) - _PIPELINE_MODEL_PARALLEL_GROUP = None - global _PIPELINE_GLOBAL_RANKS - _PIPELINE_GLOBAL_RANKS = None - - # Destroy the pynccl states if any. - pynccl_utils.destroy_process_group() - - -# Whether to use pynccl for nccl all reduce. -# We use pynccl for all reduce when using CUDA graph, because torch.distributed -# is not well supported by CUDA graph. -_ENABLE_PYNCCL_FOR_ALL_REDUCE = False - - -@contextlib.contextmanager -def with_pynccl_for_all_reduce(): - """use pynccl instead of torch.distributed for all reduce""" - tp_size = get_tensor_model_parallel_world_size() - if tp_size == 1: - # No-op. - # NOTE(woosuk): We don't initialize pynccl when tp_size is 1. - yield - else: - global _ENABLE_PYNCCL_FOR_ALL_REDUCE - old = _ENABLE_PYNCCL_FOR_ALL_REDUCE - _ENABLE_PYNCCL_FOR_ALL_REDUCE = True - - stream = torch.cuda.current_stream() - with pynccl_utils.set_pynccl_stream(stream): - yield - _ENABLE_PYNCCL_FOR_ALL_REDUCE = old - - -def is_pynccl_enabled_for_all_reduce(): - """check if pynccl is enabled for all reduce""" - global _ENABLE_PYNCCL_FOR_ALL_REDUCE - return _ENABLE_PYNCCL_FOR_ALL_REDUCE + global _TP_DEVICE_GROUP + if _TP_DEVICE_GROUP: + torch.distributed.destroy_process_group(_TP_DEVICE_GROUP) + _TP_DEVICE_GROUP = None + global _TP_CPU_GROUP + if _TP_CPU_GROUP: + torch.distributed.destroy_process_group(_TP_CPU_GROUP) + _TP_CPU_GROUP = None + global _TP_PYNCCL_COMMUNICATOR + _TP_PYNCCL_COMMUNICATOR = None + + global _PP_DEVICE_GROUP + if _PP_DEVICE_GROUP: + torch.distributed.destroy_process_group(_PP_DEVICE_GROUP) + _PP_DEVICE_GROUP = None + global _PP_GLOBAL_RANKS + _PP_GLOBAL_RANKS = None diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/distributed/utils.py similarity index 100% rename from vllm/model_executor/parallel_utils/utils.py rename to vllm/distributed/utils.py diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a24a404c05fce..9fe0d0bb0a301 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,20 +1,30 @@ import argparse import dataclasses +import json from dataclasses import dataclass -from typing import Optional +from typing import List, Optional, Tuple, Union -from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TokenizerPoolConfig, - VisionLanguageConfig) +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, + EngineConfig, LoadConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, SpeculativeConfig, + TokenizerPoolConfig, VisionLanguageConfig) +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import str_to_int_tuple +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + @dataclass class EngineArgs: """Arguments for vLLM engine.""" model: str + served_model_name: Optional[Union[List[str]]] = None tokenizer: Optional[str] = None + skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' trust_remote_code: bool = False download_dir: Optional[str] = None @@ -25,11 +35,13 @@ class EngineArgs: seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False + distributed_executor_backend: Optional[str] = None pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None block_size: int = 16 enable_prefix_caching: bool = False + disable_sliding_window: bool = False use_v2_block_manager: bool = False swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 @@ -39,10 +51,12 @@ class EngineArgs: disable_log_stats: bool = False revision: Optional[str] = None code_revision: Optional[str] = None + rope_scaling: Optional[dict] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enforce_eager: bool = False - max_context_len_to_capture: int = 8192 + max_context_len_to_capture: Optional[int] = None + max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False tokenizer_pool_size: int = 0 tokenizer_pool_type: str = "ray" @@ -50,26 +64,33 @@ class EngineArgs: enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 + fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None lora_dtype = 'auto' max_cpu_loras: Optional[int] = None device: str = 'auto' ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 + model_loader_extra_config: Optional[dict] = None # Related to Vision-language models such as llava image_input_type: Optional[str] = None image_token_id: Optional[int] = None image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None - scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False + guided_decoding_backend: str = 'outlines' # Speculative decoding configuration. speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None + speculative_max_model_len: Optional[int] = None + speculative_disable_by_batch_size: Optional[int] = None + ngram_prompt_lookup_max: Optional[int] = None + ngram_prompt_lookup_min: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -80,72 +101,79 @@ def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Shared CLI arguments for vLLM engine.""" - # NOTE: If you update any of the arguments below, please also - # make sure to update docs/source/models/engine_args.rst - # Model arguments parser.add_argument( '--model', type=str, default='facebook/opt-125m', - help='name or path of the huggingface model to use') + help='Name or path of the huggingface model to use.') parser.add_argument( '--tokenizer', - type=str, + type=nullable_str, default=EngineArgs.tokenizer, - help='name or path of the huggingface tokenizer to use') + help='Name or path of the huggingface tokenizer to use.') + parser.add_argument( + '--skip-tokenizer-init', + action='store_true', + help='Skip initialization of tokenizer and detokenizer') parser.add_argument( '--revision', - type=str, + type=nullable_str, default=None, - help='the specific model version to use. It can be a branch ' + help='The specific model version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' 'the default version.') parser.add_argument( '--code-revision', - type=str, + type=nullable_str, default=None, - help='the specific revision to use for the model code on ' + help='The specific revision to use for the model code on ' 'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'commit id. If unspecified, will use the default version.') parser.add_argument( '--tokenizer-revision', - type=str, + type=nullable_str, default=None, - help='the specific tokenizer version to use. It can be a branch ' + help='The specific tokenizer version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' 'the default version.') - parser.add_argument('--tokenizer-mode', - type=str, - default=EngineArgs.tokenizer_mode, - choices=['auto', 'slow'], - help='tokenizer mode. "auto" will use the fast ' - 'tokenizer if available, and "slow" will ' - 'always use the slow tokenizer.') + parser.add_argument( + '--tokenizer-mode', + type=str, + default=EngineArgs.tokenizer_mode, + choices=['auto', 'slow'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer.') parser.add_argument('--trust-remote-code', action='store_true', - help='trust remote code from huggingface') + help='Trust remote code from huggingface.') parser.add_argument('--download-dir', - type=str, + type=nullable_str, default=EngineArgs.download_dir, - help='directory to download and load the weights, ' + help='Directory to download and load the weights, ' 'default to the default cache dir of ' - 'huggingface') + 'huggingface.') parser.add_argument( '--load-format', type=str, default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], - help='The format of the model weights to load. ' - '"auto" will try to load the weights in the safetensors format ' + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer' + ], + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' - '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave. See the Tensorize vLLM Model script in the Examples' + 'section for more information.\n') parser.add_argument( '--dtype', type=str, @@ -153,22 +181,25 @@ def add_cli_args( choices=[ 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' ], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') + help='Data type for model weights and activations.\n\n' + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + 'BF16 precision for BF16 models.\n' + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.') parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' - 'data type. FP8_E5M2 (without scaling) is only supported on cuda ' - 'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria. ') + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', - type=str, + type=nullable_str, default=None, help='Path to the JSON file containing the KV cache ' 'scaling factors. This should generally be supplied, when ' @@ -176,52 +207,75 @@ def add_cli_args( 'default to 1.0, which may cause accuracy issues. ' 'FP8_E5M2 (without scaling) is only supported on cuda version' 'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria. ') + 'supported for common inference criteria.') parser.add_argument('--max-model-len', type=int, default=EngineArgs.max_model_len, - help='model context length. If unspecified, ' - 'will be automatically derived from the model.') + help='Model context length. If unspecified, will ' + 'be automatically derived from the model config.') + parser.add_argument( + '--guided-decoding-backend', + type=str, + default='outlines', + choices=['outlines', 'lm-format-enforcer'], + help='Which engine will be used for guided decoding' + ' (JSON schema / regex etc) by default. Currently support ' + 'https://github.com/outlines-dev/outlines and ' + 'https://github.com/noamgat/lm-format-enforcer.' + ' Can be overridden per request via guided_decoding_backend' + ' parameter.') # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU ' - 'unless on ROCm where the default is torchrun') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp', 'torchrun'], + default=EngineArgs.distributed_executor_backend, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, on CUDA this will be automatically set to "ray" if ' + 'installed or "mp" (multiprocessing) otherwise. On ROCm, this is ' + 'instead automatically set to torchrun.') + parser.add_argument( + '--worker-use-ray', + action='store_true', + help='Deprecated, use --distributed-executor-backend=ray.') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=EngineArgs.pipeline_parallel_size, - help='number of pipeline stages') + help='Number of pipeline stages.') parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=EngineArgs.tensor_parallel_size, - help='number of tensor parallel replicas') + help='Number of tensor parallel replicas.') parser.add_argument( '--max-parallel-loading-workers', type=int, default=EngineArgs.max_parallel_loading_workers, - help='load model sequentially in multiple batches, ' + help='Load model sequentially in multiple batches, ' 'to avoid RAM OOM when using tensor ' - 'parallel and large models') + 'parallel and large models.') parser.add_argument( '--ray-workers-use-nsight', action='store_true', - help='If specified, use nsight to profile ray workers') + help='If specified, use nsight to profile Ray workers.') # KV cache arguments parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32, 128], - help='token block size') + choices=[8, 16, 32], + help='Token block size for contiguous chunks of ' + 'tokens.') parser.add_argument('--enable-prefix-caching', action='store_true', - help='Enables automatic prefix caching') + help='Enables automatic prefix caching.') + parser.add_argument('--disable-sliding-window', + action='store_true', + help='Disables sliding window, ' + 'capping to sliding window size') parser.add_argument('--use-v2-block-manager', action='store_true', - help='Use BlockSpaceMangerV2') + help='Use BlockSpaceMangerV2.') parser.add_argument( '--num-lookahead-slots', type=int, @@ -234,18 +288,19 @@ def add_cli_args( parser.add_argument('--seed', type=int, default=EngineArgs.seed, - help='random seed') + help='Random seed for operations.') parser.add_argument('--swap-space', type=int, default=EngineArgs.swap_space, - help='CPU swap space size (GiB) per GPU') + help='CPU swap space size (GiB) per GPU.') parser.add_argument( '--gpu-memory-utilization', type=float, default=EngineArgs.gpu_memory_utilization, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') + help='The fraction of GPU memory to be used for the model ' + 'executor, which can range from 0 to 1. For example, a value of ' + '0.5 would imply 50%% GPU memory utilization. If unspecified, ' + 'will use the default value of 0.9.') parser.add_argument( '--num-gpu-blocks-override', type=int, @@ -255,26 +310,26 @@ def add_cli_args( parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, - help='maximum number of batched tokens per ' - 'iteration') + help='Maximum number of batched tokens per ' + 'iteration.') parser.add_argument('--max-num-seqs', type=int, default=EngineArgs.max_num_seqs, - help='maximum number of sequences per iteration') + help='Maximum number of sequences per iteration.') parser.add_argument( '--max-logprobs', type=int, default=EngineArgs.max_logprobs, - help=('max number of log probs to return logprobs is specified in' - ' SamplingParams')) + help=('Max number of log probs to return logprobs is specified in' + ' SamplingParams.')) parser.add_argument('--disable-log-stats', action='store_true', - help='disable logging statistics') + help='Disable logging statistics.') # Quantization settings. parser.add_argument('--quantization', '-q', - type=str, - choices=['awq', 'gptq', 'squeezellm', None], + type=nullable_str, + choices=[*QUANTIZATION_METHODS, None], default=EngineArgs.quantization, help='Method used to quantize the weights. If ' 'None, we first check the `quantization_config` ' @@ -282,6 +337,11 @@ def add_cli_args( 'None, we assume the model weights are not ' 'quantized and use `dtype` to determine the data ' 'type of the weights.') + parser.add_argument('--rope-scaling', + default=None, + type=json.loads, + help='RoPE scaling configuration in JSON format. ' + 'For example, {"type":"dynamic","factor":2.0}') parser.add_argument('--enforce-eager', action='store_true', help='Always use eager-mode PyTorch. If False, ' @@ -290,13 +350,21 @@ def add_cli_args( parser.add_argument('--max-context-len-to-capture', type=int, default=EngineArgs.max_context_len_to_capture, - help='maximum context length covered by CUDA ' + help='Maximum context length covered by CUDA ' + 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode. ' + '(DEPRECATED. Use --max-seq-len-to-capture instead' + ')') + parser.add_argument('--max-seq-len-to-capture', + type=int, + default=EngineArgs.max_seq_len_to_capture, + help='Maximum sequence length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') parser.add_argument('--disable-custom-all-reduce', action='store_true', default=EngineArgs.disable_custom_all_reduce, - help='See ParallelConfig') + help='See ParallelConfig.') parser.add_argument('--tokenizer-pool-size', type=int, default=EngineArgs.tokenizer_pool_size, @@ -310,7 +378,7 @@ def add_cli_args( 'asynchronous tokenization. Ignored ' 'if tokenizer_pool_size is 0.') parser.add_argument('--tokenizer-pool-extra-config', - type=str, + type=nullable_str, default=EngineArgs.tokenizer_pool_extra_config, help='Extra config for tokenizer pool. ' 'This should be a JSON string that will be ' @@ -342,6 +410,17 @@ def add_cli_args( choices=['auto', 'float16', 'bfloat16', 'float32'], help=('Data type for LoRA. If auto, will default to ' 'base model dtype.')) + parser.add_argument( + '--long-lora-scaling-factors', + type=nullable_str, + default=EngineArgs.long_lora_scaling_factors, + help=('Specify multiple scaling factors (which can ' + 'be different from base model scaling factor ' + '- see eg. Long LoRA) to allow for multiple ' + 'LoRA adapters trained with those scaling ' + 'factors to be used at the same time. If not ' + 'specified, only adapters trained with the ' + 'base model scaling factor are allowed.')) parser.add_argument( '--max-cpu-loras', type=int, @@ -349,6 +428,14 @@ def add_cli_args( help=('Maximum number of LoRAs to store in CPU memory. ' 'Must be >= than max_num_seqs. ' 'Defaults to max_num_seqs.')) + parser.add_argument( + '--fully-sharded-loras', + action='store_true', + help=('By default, only half of the LoRA computation is ' + 'sharded with tensor parallelism. ' + 'Enabling this will use the fully sharded layers. ' + 'At high sequence length, max rank or ' + 'tensor parallel size, this is likely faster.')) parser.add_argument("--device", type=str, default=EngineArgs.device, @@ -357,7 +444,7 @@ def add_cli_args( # Related to Vision-language models such as llava parser.add_argument( '--image-input-type', - type=str, + type=nullable_str, default=None, choices=[ t.name.lower() for t in VisionLanguageConfig.ImageInputType @@ -370,7 +457,7 @@ def add_cli_args( help=('Input id for image token.')) parser.add_argument( '--image-input-shape', - type=str, + type=nullable_str, default=None, help=('The biggest image input shape (worst for memory footprint) ' 'given an input type. Only used for vLLM\'s profile_run.')) @@ -387,28 +474,81 @@ def add_cli_args( 'prompt latency) before scheduling next prompt.') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, - help='If True, the prefill requests can be chunked based on the ' - 'max_num_batched_tokens') + action='store_true', + help='If set, the prefill requests can be chunked based on the ' + 'max_num_batched_tokens.') parser.add_argument( '--speculative-model', - type=str, - default=None, + type=nullable_str, + default=EngineArgs.speculative_model, help= 'The name of the draft model to be used in speculative decoding.') parser.add_argument( '--num-speculative-tokens', type=int, - default=None, + default=EngineArgs.num_speculative_tokens, help='The number of speculative tokens to sample from ' - 'the draft model in speculative decoding') + 'the draft model in speculative decoding.') + + parser.add_argument( + '--speculative-max-model-len', + type=int, + default=EngineArgs.speculative_max_model_len, + help='The maximum sequence length supported by the ' + 'draft model. Sequences over this length will skip ' + 'speculation.') + + parser.add_argument( + '--speculative-disable-by-batch-size', + type=int, + default=EngineArgs.speculative_disable_by_batch_size, + help='Disable speculative decoding for new incoming requests ' + 'if the number of enqueue requests is larger than this value.') + + parser.add_argument( + '--ngram-prompt-lookup-max', + type=int, + default=EngineArgs.ngram_prompt_lookup_max, + help='Max size of window for ngram prompt lookup in speculative ' + 'decoding.') + + parser.add_argument( + '--ngram-prompt-lookup-min', + type=int, + default=EngineArgs.ngram_prompt_lookup_min, + help='Min size of window for ngram prompt lookup in speculative ' + 'decoding.') + + parser.add_argument('--model-loader-extra-config', + type=nullable_str, + default=EngineArgs.model_loader_extra_config, + help='Extra config for model loader. ' + 'This will be passed to the model loader ' + 'corresponding to the chosen load_format. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary.') + + parser.add_argument( + "--served-model-name", + nargs="+", + type=str, + default=None, + help="The model name(s) used in the API. If multiple " + "names are provided, the server will respond to any " + "of the provided names. The model name in the model " + "field of a response will be the first name in this " + "list. If not specified, the model name will be the " + "same as the `--model` argument. Noted that this name(s)" + "will also be used in `model_name` tag content of " + "prometheus metrics, if multiple names provided, metrics" + "tag will take the first one.") + return parser @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + def from_cli_args(cls, args: argparse.Namespace): # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. @@ -419,11 +559,13 @@ def create_engine_config(self, ) -> EngineConfig: device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, - self.trust_remote_code, self.download_dir, self.load_format, - self.dtype, self.seed, self.revision, self.code_revision, - self.tokenizer_revision, self.max_model_len, self.quantization, + self.trust_remote_code, self.dtype, self.seed, self.revision, + self.code_revision, self.rope_scaling, self.tokenizer_revision, + self.max_model_len, self.quantization, self.quantization_param_path, self.enforce_eager, - self.max_context_len_to_capture, self.max_logprobs) + self.max_context_len_to_capture, self.max_seq_len_to_capture, + self.max_logprobs, self.disable_sliding_window, + self.skip_tokenizer_init, self.served_model_name) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, @@ -431,14 +573,18 @@ def create_engine_config(self, ) -> EngineConfig: model_config.get_sliding_window(), self.enable_prefix_caching) parallel_config = ParallelConfig( - self.pipeline_parallel_size, self.tensor_parallel_size, - self.worker_use_ray, self.max_parallel_loading_workers, + self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers, self.disable_custom_all_reduce, TokenizerPoolConfig.create_config( self.tokenizer_pool_size, self.tokenizer_pool_type, self.tokenizer_pool_extra_config, - ), self.ray_workers_use_nsight) + ), + self.ray_workers_use_nsight, + distributed_executor_backend=self.distributed_executor_backend) speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, @@ -446,6 +592,13 @@ def create_engine_config(self, ) -> EngineConfig: target_dtype=self.dtype, speculative_model=self.speculative_model, num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_by_batch_size=self. + speculative_disable_by_batch_size, + speculative_max_model_len=self.speculative_max_model_len, + enable_chunked_prefill=self.enable_chunked_prefill, + use_v2_block_manager=self.use_v2_block_manager, + ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, ) scheduler_config = SchedulerConfig( @@ -458,15 +611,24 @@ def create_engine_config(self, ) -> EngineConfig: speculative_config.num_lookahead_slots), delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, + embedding_mode=model_config.embedding_mode, ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, + long_lora_scaling_factors=self.long_lora_scaling_factors, lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + load_config = LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ) + if self.image_input_type: if (not self.image_token_id or not self.image_input_shape or not self.image_feature_size): @@ -483,6 +645,16 @@ def create_engine_config(self, ) -> EngineConfig: else: vision_language_config = None + decoding_config = DecodingConfig( + guided_decoding_backend=self.guided_decoding_backend) + + if (model_config.get_sliding_window() is not None + and scheduler_config.chunked_prefill_enabled + and not scheduler_config.use_v2_block_manager): + raise ValueError( + "Chunked prefill is not supported with sliding window. " + "Set --disable-sliding-window to disable sliding window.") + return EngineConfig(model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, @@ -490,7 +662,9 @@ def create_engine_config(self, ) -> EngineConfig: device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, - speculative_config=speculative_config) + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config) @dataclass @@ -501,20 +675,31 @@ class AsyncEngineArgs(EngineArgs): max_log_len: Optional[int] = None @staticmethod - def add_cli_args( - parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - parser = EngineArgs.add_cli_args(parser) + def add_cli_args(parser: argparse.ArgumentParser, + async_args_only: bool = False) -> argparse.ArgumentParser: + if not async_args_only: + parser = EngineArgs.add_cli_args(parser) parser.add_argument('--engine-use-ray', action='store_true', - help='use Ray to start the LLM engine in a ' + help='Use Ray to start the LLM engine in a ' 'separate process as the server process.') parser.add_argument('--disable-log-requests', action='store_true', - help='disable logging requests') + help='Disable logging requests.') parser.add_argument('--max-log-len', type=int, default=None, - help='max number of prompt characters or prompt ' - 'ID numbers being printed in log. ' - 'Default: unlimited.') + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') return parser + + +# These functions are used by sphinx to build the documentation +def _engine_args_parser(): + return EngineArgs.add_cli_args(argparse.ArgumentParser()) + + +def _async_engine_args_parser(): + return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(), + async_args_only=True) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f610495135121..db4d2849b3f0e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,5 +1,4 @@ import asyncio -import os import time from functools import partial from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, @@ -7,20 +6,23 @@ from transformers import PreTrainedTokenizer -from vllm.config import ModelConfig +import vllm.envs as envs +from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_ray_cluster, ray +from vllm.executor.ray_utils import initialize_ray_cluster, ray +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) -ENGINE_ITERATION_TIMEOUT_S = int( - os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")) +ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S class AsyncEngineDeadError(RuntimeError): @@ -47,15 +49,16 @@ def _raise_exception_on_finish( class AsyncStream: - """A stream of RequestOutputs for a request that can be - iterated over asynchronously.""" + """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + that can be iterated over asynchronously.""" def __init__(self, request_id: str) -> None: self.request_id = request_id - self._queue = asyncio.Queue() + self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, Exception]) -> None: + def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + Exception]) -> None: if self._finished: return self._queue.put_nowait(item) @@ -71,7 +74,7 @@ def finished(self) -> bool: def __aiter__(self): return self - async def __anext__(self) -> RequestOutput: + async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]: result = await self._queue.get() if isinstance(result, Exception): raise result @@ -108,7 +111,8 @@ def propagate_exception(self, self.abort_request(rid) def process_request_output(self, - request_output: RequestOutput, + request_output: Union[RequestOutput, + EmbeddingRequestOutput], *, verbose: bool = False) -> None: """Process a request output from the engine.""" @@ -117,7 +121,7 @@ def process_request_output(self, self._request_streams[request_id].put(request_output) if request_output.finished: if verbose: - logger.info(f"Finished request {request_id}.") + logger.info("Finished request %s.", request_id) self.abort_request(request_id) def process_exception(self, @@ -128,7 +132,7 @@ def process_exception(self, """Propagate an exception from the engine.""" self._request_streams[request_id].put(exception) if verbose: - logger.info(f"Finished request {request_id}.") + logger.info("Finished request %s.", request_id) self.abort_request(request_id) def add_request(self, request_id: str, @@ -151,7 +155,7 @@ def add_request(self, request_id: str, def abort_request(self, request_id: str, *, verbose: bool = False) -> None: """Abort a request during next background loop iteration.""" if verbose: - logger.info(f"Aborted request {request_id}.") + logger.info("Aborted request %s.", request_id) self._finished_requests.put_nowait(request_id) @@ -196,7 +200,8 @@ def has_new_requests(self): class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" - async def step_async(self) -> List[RequestOutput]: + async def step_async( + self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. @@ -210,73 +215,97 @@ async def step_async(self) -> List[RequestOutput]: if not scheduler_outputs.is_empty(): # Execute the model. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + ) output = await self.model_executor.execute_model_async( - seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, - scheduler_outputs.blocks_to_swap_out, - scheduler_outputs.blocks_to_copy) + execute_model_req) else: output = [] - return self._process_model_outputs(output, scheduler_outputs) + request_outputs = self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + + # Log stats. + self.do_log_stats(scheduler_outputs, output) + + if not request_outputs: + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + await self.model_executor.stop_remote_worker_execution_loop_async() - async def encode_request_async( + return request_outputs + + async def process_model_inputs_async( self, - request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + request_id: str, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = await self.tokenizer.encode_async( + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = await tokenizer.encode_async( request_id=request_id, - prompt=prompt, + prompt=inputs["prompt"], lora_request=lora_request) - return prompt_token_ids + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) async def add_request_async( self, request_id: str, - prompt: Optional[str], - sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]] = None, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = await self.encode_request_async( + + processed_inputs = await self.process_model_inputs_async( + request_id=request_id, inputs=inputs, lora_request=lora_request) + + self._add_processed_request( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - - return self.add_request(request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, - arrival_time=arrival_time, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + ) async def check_health_async(self) -> None: self.model_executor.check_health() class AsyncLLMEngine: - """An asynchronous wrapper for LLMEngine. + """An asynchronous wrapper for :class:`LLMEngine`. - This class is used to wrap the LLMEngine class to make it asynchronous. It - uses asyncio to create a background loop that keeps processing incoming - requests. The LLMEngine is kicked by the generate method when there - are requests in the waiting queue. The generate method yields the outputs - from the LLMEngine to the caller. - - NOTE: For the comprehensive list of arguments, see `LLMEngine`. + This class is used to wrap the :class:`LLMEngine` class to make it + asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The :class:`LLMEngine` is kicked by the + generate method when there are requests in the waiting queue. The generate + method yields the outputs from the :class:`LLMEngine` to the caller. Args: worker_use_ray: Whether to use Ray for model workers. Required for @@ -290,8 +319,8 @@ class AsyncLLMEngine: being printed in log. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. - *args: Arguments for LLMEngine. - *kwargs: Arguments for LLMEngine. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. """ _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine @@ -310,15 +339,17 @@ def __init__(self, self.max_log_len = max_log_len self.engine = self._init_engine(*args, **kwargs) - self.background_loop = None + self.background_loop: Optional[asyncio.Future] = None # We need to keep a reference to unshielded # task as well to prevent it from being garbage # collected - self._background_loop_unshielded = None + self._background_loop_unshielded: Optional[asyncio.Task] = None self.start_engine_loop = start_engine_loop - self._request_tracker: Optional[RequestTracker] = None self._errored_with: Optional[BaseException] = None + # Lazy initialized fields + self._request_tracker: RequestTracker + @classmethod def from_engine_args( cls, @@ -329,23 +360,31 @@ def from_engine_args( """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. engine_config = engine_args.create_engine_config() + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) if engine_config.device_config.device_type == "neuron": - raise NotImplementedError("Neuron is not supported for " - "async engine yet.") - elif (engine_config.parallel_config.worker_use_ray - or engine_args.engine_use_ray): + from vllm.executor.neuron_executor import NeuronExecutorAsync + executor_class = NeuronExecutorAsync + elif engine_config.device_config.device_type == "cpu": + assert distributed_executor_backend is None, ( + "Distributed execution is not supported with the CPU backend.") + from vllm.executor.cpu_executor import CPUExecutorAsync + executor_class = CPUExecutorAsync + elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutorAsync) + executor_class = MultiprocessingGPUExecutorAsync else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutorAsync executor_class = GPUExecutorAsync # Create the async LLM engine. engine = cls( - engine_config.parallel_config.worker_use_ray, + distributed_executor_backend == "ray", engine_args.engine_use_ray, **engine_config.to_dict(), executor_class=executor_class, @@ -360,11 +399,13 @@ def from_engine_args( @property def is_running(self) -> bool: return (self.background_loop is not None + and self._background_loop_unshielded is not None and not self._background_loop_unshielded.done()) @property def is_stopped(self) -> bool: - return self.errored or (self.background_loop is not None + return self.errored or (self.background_loop is not None and + self._background_loop_unshielded is not None and self._background_loop_unshielded.done()) @property @@ -380,7 +421,7 @@ def _error_callback(self, exc: Exception) -> None: async def get_tokenizer(self) -> "PreTrainedTokenizer": if self.engine_use_ray: - return await self.engine.get_tokenizer.remote() + return await self.engine.get_tokenizer.remote() # type: ignore else: return self.engine.get_tokenizer() @@ -410,8 +451,8 @@ def _init_engine(self, *args, else: # FIXME(woosuk): This is a bit hacky. Be careful when changing the # order of the arguments. - cache_config = args[1] - parallel_config = args[2] + cache_config = kwargs["cache_config"] + parallel_config = kwargs["parallel_config"] if parallel_config.tensor_parallel_size == 1: num_gpus = cache_config.gpu_memory_utilization else: @@ -433,7 +474,8 @@ async def engine_step(self) -> bool: # TODO: Maybe add add_request_batch to reduce Ray overhead try: if self.engine_use_ray: - await self.engine.add_request.remote(**new_request) + await self.engine.add_request.remote( # type: ignore + **new_request) else: await self.engine.add_request_async(**new_request) except ValueError as e: @@ -448,7 +490,7 @@ async def engine_step(self) -> bool: await self._engine_abort(finished_requests) if self.engine_use_ray: - request_outputs = await self.engine.step.remote() + request_outputs = await self.engine.step.remote() # type: ignore else: request_outputs = await self.engine.step_async() @@ -461,7 +503,7 @@ async def engine_step(self) -> bool: async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: - await self.engine.abort_request.remote(request_ids) + await self.engine.abort_request.remote(request_ids) # type: ignore else: self.engine.abort_request(request_ids) @@ -488,27 +530,31 @@ async def run_engine_loop(self): async def add_request( self, request_id: str, - prompt: Optional[str], - sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]] = None, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncStream: if self.log_requests: - shortened_prompt = prompt - shortened_token_ids = prompt_token_ids - if self.max_log_len is not None: + if isinstance(inputs, str): + shortened_prompt = inputs + shortened_token_ids = None + else: + shortened_prompt = inputs.get("prompt") + shortened_token_ids = inputs.get("prompt_token_ids") + + max_log_len = self.max_log_len + if max_log_len is not None: if shortened_prompt is not None: - shortened_prompt = shortened_prompt[:self.max_log_len] + shortened_prompt = shortened_prompt[:max_log_len] if shortened_token_ids is not None: - shortened_token_ids = shortened_token_ids[:self. - max_log_len] - logger.info(f"Received request {request_id}: " - f"prompt: {shortened_prompt!r}, " - f"sampling_params: {sampling_params}, " - f"prompt_token_ids: {shortened_token_ids}, " - f"lora_request: {lora_request}.") + shortened_token_ids = shortened_token_ids[:max_log_len] + + logger.info( + "Received request %s: prompt: %r, " + "params: %s, prompt_token_ids: %s, " + "lora_request: %s.", request_id, shortened_prompt, params, + shortened_token_ids, lora_request) if not self.is_running: if self.start_engine_loop: @@ -524,38 +570,33 @@ async def add_request( arrival_time = time.time() if self.engine_use_ray: - prompt_token_ids = await self.engine.encode_request_async.remote( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) + processed_inputs = await self.engine.process_model_inputs_async \ + .remote( # type: ignore + request_id=request_id, + inputs=inputs, + lora_request=lora_request) else: - prompt_token_ids = await self.engine.encode_request_async( + processed_inputs = await self.engine.process_model_inputs_async( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, + inputs=inputs, lora_request=lora_request) stream = self._request_tracker.add_request( request_id, - prompt=prompt, - sampling_params=sampling_params, - prompt_token_ids=prompt_token_ids, + inputs=processed_inputs, + params=params, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) return stream async def generate( self, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -564,18 +605,16 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data per request. Yields: - The output `RequestOutput` objects from the LLMEngine for the - request. + The output `RequestOutput` objects from the LLMEngine + for the request. Details: - If the engine is not running, start the background loop, @@ -620,25 +659,112 @@ async def generate( >>> # Process and return the final output >>> ... """ - # Preprocess the request. - arrival_time = time.time() - - try: - stream = await self.add_request( + async for output in self._process_request( request_id, - prompt, + inputs, sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, - ) + ): + yield LLMEngine.validate_output(output, RequestOutput) + + async def encode( + self, + inputs: PromptInputs, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> AsyncIterator[EmbeddingRequestOutput]: + """Generate outputs for a request from an embedding model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + + Yields: + The output `EmbeddingRequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "input": "What is LLM?", + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.encode( + >>> example_input["input"], + >>> PoolingParams(), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... + """ + async for output in self._process_request( + request_id, + inputs, + pooling_params, + lora_request=lora_request, + ): + yield LLMEngine.validate_output(output, EmbeddingRequestOutput) + + async def _process_request( + self, + request_id: str, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + *, + lora_request: Optional[LoRARequest] = None, + ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: + """Common logic to process requests with SamplingParams or + PoolingParams.""" + arrival_time = time.time() + + stream = await self.add_request( + request_id, + inputs, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + try: async for request_output in stream: yield request_output except (Exception, asyncio.CancelledError) as e: - # If there is an exception or coroutine is cancelled, abort the - # request. self._abort(request_id) raise e @@ -675,13 +801,25 @@ def _abort(self, request_id: str) -> None: async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" if self.engine_use_ray: - return await self.engine.get_model_config.remote() + return await self.engine.get_model_config.remote() # type: ignore else: return self.engine.get_model_config() - async def do_log_stats(self) -> None: + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_decoding_config.remote( # type: ignore + ) + else: + return self.engine.get_decoding_config() + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None) -> None: if self.engine_use_ray: - await self.engine.do_log_stats.remote() + await self.engine.do_log_stats.remote( # type: ignore + scheduler_outputs, model_output) else: self.engine.do_log_stats() @@ -694,9 +832,9 @@ async def check_health(self) -> None: if self.engine_use_ray: try: - await self.engine.check_health.remote() + await self.engine.check_health.remote() # type: ignore except ray.exceptions.RayActorError as e: raise RuntimeError("Engine is dead.") from e else: await self.engine.check_health_async() - logger.debug(f"Health check took {time.perf_counter()-t}s") + logger.debug("Health check took %fs", time.perf_counter() - t) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7dfe6290c5ec4..11b800afb38a1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,23 +1,36 @@ import time -from typing import Iterable, List, Optional, Tuple, Type, Union +from contextlib import contextmanager +from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Type, TypeVar, Union -from transformers import PreTrainedTokenizer +from transformers import GenerationConfig, PreTrainedTokenizer import vllm -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, + LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) -from vllm.core.scheduler import Scheduler, SchedulerOutputs +from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, + SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats -from vllm.engine.ray_utils import initialize_ray_cluster +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase +from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, + RequestOutputFactory) +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupOutput, SequenceOutput, +from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, + PoolerOutput, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, @@ -30,6 +43,20 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5 +def _load_generation_config_dict(model_config: ModelConfig): + try: + return GenerationConfig.from_pretrained( + model_config.model, + revision=model_config.revision, + ).to_diff_dict() + except OSError: + # Not found. + return {} + + +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -40,11 +67,11 @@ class LLMEngine: iteration-level scheduling and efficient memory management to maximize the serving throughput. - The `LLM` class wraps this class for offline batched inference and the - `AsyncLLMEngine` class wraps this class for online serving. + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. - NOTE: The config arguments are derived from the `EngineArgs` class. For the - comprehensive list of arguments, see `EngineArgs`. + The config arguments are derived from :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) Args: model_config: The configuration related to the LLM model. @@ -61,9 +88,60 @@ class LLMEngine: executor_class: The model executor class for managing distributed execution. log_stats: Whether to log statistics. - usage_context: Specified entry point, used for usage info collection + usage_context: Specified entry point, used for usage info collection. """ + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return output + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + tokenizer: Optional[BaseTokenizerGroup] + def __init__( self, model_config: ModelConfig, @@ -71,35 +149,51 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, ) -> None: logger.info( - f"Initializing an LLM engine (v{vllm.__version__}) with config: " - f"model={model_config.model!r}, " - f"speculative_config={speculative_config!r}, " - f"tokenizer={model_config.tokenizer!r}, " - f"tokenizer_mode={model_config.tokenizer_mode}, " - f"revision={model_config.revision}, " - f"tokenizer_revision={model_config.tokenizer_revision}, " - f"trust_remote_code={model_config.trust_remote_code}, " - f"dtype={model_config.dtype}, " - f"max_seq_len={model_config.max_model_len}, " - f"download_dir={model_config.download_dir!r}, " - f"load_format={model_config.load_format}, " - f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " - f"disable_custom_all_reduce=" - f"{parallel_config.disable_custom_all_reduce}, " - f"quantization={model_config.quantization}, " - f"enforce_eager={model_config.enforce_eager}, " - f"kv_cache_dtype={cache_config.cache_dtype}, " - f"quantization_param_path={model_config.quantization_param_path}, " - f"device_config={device_config.device}, " - f"seed={model_config.seed})") + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "rope_scaling=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, seed=%d, served_model_name=%s)", + vllm.__version__, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.tokenizer_mode, + model_config.revision, + model_config.rope_scaling, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + model_config.seed, + model_config.served_model_name, + ) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config @@ -110,11 +204,20 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.load_config = load_config + self.decoding_config = decoding_config or DecodingConfig() self.log_stats = log_stats - self._init_tokenizer() - self.detokenizer = Detokenizer(self.tokenizer) + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) + else: + self.tokenizer = None + self.detokenizer = None + self.seq_counter = Counter() + self.generation_config_fields = _load_generation_config_dict( + model_config) self.model_executor = executor_class( model_config=model_config, @@ -125,9 +228,11 @@ def __init__( lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, + load_config=load_config, ) - self._initialize_kv_caches() + if not self.model_config.embedding_mode: + self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): @@ -164,9 +269,10 @@ def __init__( parallel_config.disable_custom_all_reduce, }) - # Ping the tokenizer to ensure liveness if it runs in a - # different process. - self.tokenizer.ping() + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() # Create the scheduler. # NOTE: the cache_config here have been updated with the numbers of @@ -177,9 +283,25 @@ def __init__( if self.log_stats: self.stat_logger = StatLogger( local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.model)) + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len) self.stat_logger.info("cache_config", self.cache_config) + # Create sequence output processor, e.g. for beam search or + # speculative decoding. + self.output_processor = ( + SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + self.get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + self.get_tokenizer_for_seq, + ), + )) + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -191,8 +313,10 @@ def _initialize_kv_caches(self) -> None: if self.cache_config.num_gpu_blocks_override is not None: num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override - logger.info(f"Overriding {num_gpu_blocks=} with " - f"{num_gpu_blocks_override=}") + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) num_gpu_blocks = num_gpu_blocks_override self.cache_config.num_gpu_blocks = num_gpu_blocks @@ -209,6 +333,8 @@ def from_engine_args( """Creates an LLM engine from the engine arguments.""" # Create the engine configs. engine_config = engine_args.create_engine_config() + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. if engine_config.device_config.device_type == "neuron": @@ -217,16 +343,18 @@ def from_engine_args( elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor - elif engine_config.parallel_config.worker_use_ray: + elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor - elif engine_config.parallel_config.worker_use_torchrun: + elif distributed_executor_backend == "torchrun": from vllm.executor.torchrun_gpu_executor import TorchrunGPUExecutor executor_class = TorchrunGPUExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutor) + executor_class = MultiprocessingGPUExecutor else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor @@ -244,14 +372,32 @@ def __reduce__(self): # the closure used to initialize Ray worker actors raise RuntimeError("LLMEngine should not be pickled!") + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if model_executor := getattr(self, "model_executor", None): + model_executor.shutdown() + + MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because " + "skip_tokenizer_init is True") + + def get_tokenizer_group( + self, + fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup: + if self.tokenizer is None: + raise ValueError(fail_msg) + + return self.tokenizer + def get_tokenizer(self) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(None) + return self.get_tokenizer_group().get_lora_tokenizer(None) def get_tokenizer_for_seq(self, sequence: Sequence) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + return self.get_tokenizer_group().get_lora_tokenizer( + sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs): + def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: init_kwargs = dict( tokenizer_id=self.model_config.tokenizer, enable_lora=bool(self.lora_config), @@ -261,8 +407,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer: BaseTokenizerGroup = get_tokenizer_group( - self.parallel_config.tokenizer_pool_config, **init_kwargs) + + return get_tokenizer_group(self.parallel_config.tokenizer_pool_config, + **init_kwargs) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -272,29 +419,85 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) - def encode_request( + def _get_eos_token_id( + self, lora_request: Optional[LoRARequest]) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + + def _add_processed_request( + self, + request_id: str, + processed_inputs: LLMInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + ) -> None: + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + eos_token_id = self._get_eos_token_id(lora_request) + + seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + lora_request) + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + + def process_model_inputs( self, - request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + request_id: str, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - return prompt_token_ids + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["prompt"], + lora_request=lora_request) + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) def add_request( self, request_id: str, - prompt: Optional[str], - sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]] = None, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: """Add a request to the engine's request pool. @@ -304,14 +507,14 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt string. Can be None if prompt_token_ids is - provided. - sampling_params: The sampling parameters for text generation. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. + params: Parameters for sampling or pooling. + :class:`~vllm.SamplingParams` for text generation. + :class:`~vllm.PoolingParams` for pooling. arrival_time: The arrival time of the request. If None, we use the current monotonic time. - multi_modal_data: Multi modal data per request. Details: - Set arrival_time to the current time if it is None. @@ -340,6 +543,30 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + if arrival_time is None: + arrival_time = time.time() + + processed_inputs = self.process_model_inputs(request_id=request_id, + inputs=inputs, + lora_request=lora_request) + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + + def _create_sequence_group_with_sampling( + self, + request_id: str, + seq: Sequence, + sampling_params: SamplingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + ) -> SequenceGroup: + """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs if (sampling_params.logprobs and sampling_params.logprobs > max_logprobs) or ( @@ -347,35 +574,44 @@ def add_request( and sampling_params.prompt_logprobs > max_logprobs): raise ValueError(f"Cannot request more than " f"{max_logprobs} logprobs.") - if arrival_time is None: - arrival_time = time.time() - prompt_token_ids = self.encode_request( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) - eos_token_id = self.tokenizer.get_lora_tokenizer( - lora_request).eos_token_id - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - eos_token_id, lora_request) # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects sampling_params = sampling_params.clone() - # inject the eos token id into the sampling_params to support min_tokens + # Add the eos token id into the sampling_params to support min_tokens # processing - sampling_params.eos_token_id = seq.eos_token_id + if seq.eos_token_id is not None: + sampling_params.all_stop_token_ids.add(seq.eos_token_id) + sampling_params.update_from_generation_config( + self.generation_config_fields) # Create the sequence group. - seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, lora_request, multi_modal_data) + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request) - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + return seq_group + + def _create_sequence_group_with_pooling( + self, + request_id: str, + seq: Sequence, + pooling_params: PoolingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + ) -> SequenceGroup: + """Creates a SequenceGroup with PoolingParams.""" + # Defensive copy of PoolingParams, which are used by the pooler + pooling_params = pooling_params.clone() + # Create the sequence group. + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params) + return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a request(s) with the given ID. @@ -400,6 +636,10 @@ def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" return self.model_config + def get_decoding_config(self) -> DecodingConfig: + """Gets the decoding configuration.""" + return self.decoding_config + def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" return self.scheduler.get_num_unfinished_seq_groups() @@ -408,256 +648,69 @@ def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() - def _check_beam_search_early_stopping( + def _process_sequence_group_outputs( self, - early_stopping: Union[bool, str], - sampling_params: SamplingParams, - best_running_seq: Sequence, - current_worst_seq: Sequence, - ) -> bool: - assert sampling_params.use_beam_search - length_penalty = sampling_params.length_penalty - if early_stopping is True: - return True - - current_worst_score = current_worst_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=current_worst_seq.eos_token_id) - if early_stopping is False: - highest_attainable_score = best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id) - else: - assert early_stopping == "never" - if length_penalty > 0.0: - # If length_penalty > 0.0, beam search will prefer longer - # sequences. The highest attainable score calculation is - # based on the longest possible sequence length in this case. - max_possible_length = max( - best_running_seq.get_prompt_len() + - sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id, - seq_len=max_possible_length)) - else: - # Otherwise, beam search will prefer shorter sequences. The - # highest attainable score calculation is based on the current - # sequence length. - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id)) - return current_worst_score >= highest_attainable_score - - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: - - # Process prompt logprobs - prompt_logprobs = outputs.prompt_logprobs - if prompt_logprobs is not None and seq_group.sampling_params.detokenize: - self.detokenizer.decode_prompt_logprobs_inplace( - seq_group, prompt_logprobs) - seq_group.prompt_logprobs = prompt_logprobs - - # Process samples - samples = outputs.samples - parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - existing_finished_seqs = seq_group.get_finished_seqs() - parent_child_dict = { - parent_seq.seq_id: [] - for parent_seq in parent_seqs - } - for sample in samples: - parent_child_dict[sample.parent_seq_id].append(sample) - # List of (child, parent) - child_seqs: List[Tuple[Sequence, Sequence]] = [] - - # Process the child samples for each parent sequence - for parent in parent_seqs: - child_samples: List[SequenceOutput] = parent_child_dict[ - parent.seq_id] - if len(child_samples) == 0: - # This parent sequence has no children samples. Remove - # the parent sequence from the sequence group since it will - # not be used in the future iterations. - parent.status = SequenceStatus.FINISHED_ABORTED - seq_group.remove(parent.seq_id) - self.scheduler.free_seq(parent) - continue - # Fork the parent sequence if there are multiple child samples. - for child_sample in child_samples[:-1]: - new_child_seq_id = next(self.seq_counter) - child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, - child_sample.logprobs) - child_seqs.append((child, parent)) - # Continue the parent sequence for the last child sample. - # We reuse the parent sequence here to reduce redundant memory - # copies, especially when using non-beam search sampling methods. - last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) - child_seqs.append((parent, parent)) - - for seq, _ in child_seqs: - if seq_group.sampling_params.detokenize: - self.detokenizer.decode_sequence_inplace( - seq, seq_group.sampling_params) - self._check_stop(seq, seq_group.sampling_params) - - # Non-beam search case - if not seq_group.sampling_params.use_beam_search: - # For newly created child sequences, add them to the sequence group - # and fork them in block manager if they are not finished. - for seq, parent in child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - # NOTE: we need to fork the new sequences before freeing the - # old sequences. - for seq, parent in child_seqs: - if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) - return - - # Beam search case - # Select the child sequences to keep in the sequence group. - selected_child_seqs = [] - unselected_child_seqs = [] - beam_width = seq_group.sampling_params.best_of - length_penalty = seq_group.sampling_params.length_penalty - - # Select the newly finished sequences with the highest scores - # to replace existing finished sequences. - # Tuple of (seq, parent, is_new) - existing_finished_seqs = [(seq, None, False) - for seq in existing_finished_seqs] - new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs - if seq.is_finished()] - all_finished_seqs = existing_finished_seqs + new_finished_seqs - # Sort the finished sequences by their scores. - all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - reverse=True) - for seq, parent, is_new in all_finished_seqs[:beam_width]: - if is_new: - # A newly generated child sequence finishes and has a high - # score, so we will add it into the sequence group. - selected_child_seqs.append((seq, parent)) - for seq, parent, is_new in all_finished_seqs[beam_width:]: - if is_new: - # A newly generated child sequence finishes but has a low - # score, so we will not add it into the sequence group. - # Additionally, if this sequence is a continuation of a - # parent sequence, we will need remove the parent sequence - # from the sequence group. - unselected_child_seqs.append((seq, parent)) - else: - # An existing finished sequence has a low score, so we will - # remove it from the sequence group. - seq_group.remove(seq.seq_id) - - # select the top beam_width sequences from the running - # sequences for the next iteration to continue the beam - # search. - running_child_seqs = [(seq, parent) for seq, parent in child_seqs - if not seq.is_finished()] - # Sort the running sequences by their scores. - running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - reverse=True) - - # Check if we can stop the beam search. - if len(running_child_seqs) == 0: - # No running sequences, stop the beam search. - stop_beam_search = True - elif len(all_finished_seqs) < beam_width: - # Not enough finished sequences, continue the beam search. - stop_beam_search = False - else: - # Check the early stopping criteria - best_running_seq = running_child_seqs[0][0] - current_worst_seq = all_finished_seqs[beam_width - 1][0] - stop_beam_search = self._check_beam_search_early_stopping( - seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, current_worst_seq) - - if stop_beam_search: - # Stop the beam search and remove all the running sequences from - # the sequence group. - unselected_child_seqs.extend(running_child_seqs) - else: - # Continue the beam search and select the top beam_width sequences - # to continue the beam search. - selected_child_seqs.extend(running_child_seqs[:beam_width]) - # The remaining running sequences will not be used in the next - # iteration. Again, if these sequences are continuations of - # parent sequences, we will need to remove the parent sequences - # from the sequence group. - unselected_child_seqs.extend(running_child_seqs[beam_width:]) - - # For newly created child sequences, add them to the sequence group - # and fork them in block manager if they are not finished. - for seq, parent in selected_child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - for seq, parent in selected_child_seqs: - if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) - - # Remove the unselected parent sequences from the sequence group and - # free their memory in block manager. - for seq, parent in unselected_child_seqs: - if seq is parent: - # Remove the parent sequence if it is not selected for next - # iteration - seq_group.remove(seq.seq_id) - self.scheduler.free_seq(seq) + seq_group: SequenceGroup, + outputs: List[EmbeddingSequenceGroupOutput], + ) -> None: + seq_group.embeddings = outputs[0].embeddings + + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_STOPPED + + return def _process_model_outputs( - self, output: SamplerOutput, - scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + self, + output: GenericSequence[Union[SamplerOutput, PoolerOutput]], + scheduled_seq_groups: List[ScheduledSequenceGroup], + ignored_seq_groups: List[SequenceGroup], + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + """Apply the model output to the sequences in the scheduled seq groups. + + Returns RequestOutputs that can be returned to the client. + """ + now = time.time() + + # Organize outputs by [sequence group][step] instead of + # [step][sequence group]. + output_by_sequence_group = create_output_by_sequence_group( + output, num_seq_groups=len(scheduled_seq_groups)) + # Update the scheduled sequence groups with the model outputs. - scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): + for scheduled_seq_group, outputs, seq_group_meta in zip( + scheduled_seq_groups, output_by_sequence_group, + seq_group_metadata_list): seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - self._process_sequence_group_outputs(seq_group, outputs) + if self.model_config.embedding_mode: + self._process_sequence_group_outputs(seq_group, outputs) + continue + + self.output_processor.process_prompt_logprob(seq_group, outputs) + if seq_group_meta.do_sample: + self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() # Create the outputs. - request_outputs: List[RequestOutput] = [] + request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] for scheduled_seq_group in scheduled_seq_groups: seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutput.from_seq_group(seq_group) + request_output = RequestOutputFactory.create(seq_group) request_outputs.append(request_output) - for seq_group in scheduler_outputs.ignored_seq_groups: - request_output = RequestOutput.from_seq_group(seq_group) + for seq_group in ignored_seq_groups: + request_output = RequestOutputFactory.create(seq_group) request_outputs.append(request_output) - - # Log stats. - if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs)) return request_outputs - def step(self) -> List[RequestOutput]: + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. .. figure:: https://i.imgur.com/sv2HssD.png @@ -697,7 +750,7 @@ def step(self) -> List[RequestOutput]: >>> while True: >>> if example_inputs: >>> req_id, prompt, sampling_params = example_inputs.pop(0) - >>> engine.add_request(str(req_id), prompt, sampling_params) + >>> engine.add_request(str(req_id),prompt,sampling_params) >>> >>> # continue the request processing >>> request_outputs = engine.step() @@ -711,144 +764,208 @@ def step(self) -> List[RequestOutput]: seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if not scheduler_outputs.is_empty(): + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + ) output = self.model_executor.execute_model( - seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, - scheduler_outputs.blocks_to_swap_out, - scheduler_outputs.blocks_to_copy) + execute_model_req=execute_model_req) else: output = [] - return self._process_model_outputs(output, scheduler_outputs) + request_outputs = self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + + # Log stats. + self.do_log_stats(scheduler_outputs, output) + + if not request_outputs: + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + self.model_executor.stop_remote_worker_execution_loop() + + return request_outputs - def do_log_stats(self) -> None: + def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs=None)) + self.stat_logger.log( + self._get_stats(scheduler_outputs, model_output)) + + def _get_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs], + model_output: Optional[List[SamplerOutput]] = None) -> Stats: + """Get Stats to be Logged to Prometheus. - def _get_stats(self, - scheduler_outputs: Optional[SchedulerOutputs]) -> Stats: - """Get Stats to be Logged to Prometheus.""" + Args: + scheduler_outputs: Optional, used to populate metrics related to + the scheduled batch, + model_output: Optional, used to emit speculative decoding metrics + which are created by the workers. + """ now = time.time() - # KV Cache Usage in %. + # System State + # Scheduler State + num_running_sys = len(self.scheduler.running) + num_swapped_sys = len(self.scheduler.swapped) + num_waiting_sys = len(self.scheduler.waiting) + + # KV Cache Usage in % num_total_gpu = self.cache_config.num_gpu_blocks - num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() - gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu) + gpu_cache_usage_sys = 0. + if num_total_gpu is not None: + num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks( + ) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) num_total_cpu = self.cache_config.num_cpu_blocks - cpu_cache_usage = 0. - if num_total_cpu > 0: + cpu_cache_usage_sys = 0. + if num_total_cpu is not None and num_total_cpu > 0: num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( ) - cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu) - - # Scheduler State - num_running = len(self.scheduler.running) - num_swapped = len(self.scheduler.swapped) - num_waiting = len(self.scheduler.waiting) - - # Iteration stats if we have scheduler output. - num_prompt_tokens = 0 - num_generation_tokens = 0 - time_to_first_tokens = [] - time_per_output_tokens = [] - time_e2e_requests = [] + cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + + # Iteration stats + num_prompt_tokens_iter = 0 + num_generation_tokens_iter = 0 + time_to_first_tokens_iter: List[float] = [] + time_per_output_tokens_iter: List[float] = [] + num_preemption_iter = (0 if scheduler_outputs is None else + scheduler_outputs.preempted) + + # Request stats + # Latency + time_e2e_requests: List[float] = [] + # Metadata + num_prompt_tokens_requests: List[int] = [] + num_generation_tokens_requests: List[int] = [] + best_of_requests: List[int] = [] + n_requests: List[int] = [] + finished_reason_requests: List[str] = [] + + # NOTE: This loop assumes prefill seq_groups are before + # decode seq_groups in scheduled_seq_groups. if scheduler_outputs is not None: - prompt_run = scheduler_outputs.num_prefill_groups > 0 - - # Number of Tokens. - if prompt_run: - num_prompt_tokens = sum( - len(scheduled_seq_group.seq_group.prompt_token_ids) - for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups) - num_generation_tokens = sum( - scheduled_seq_group.seq_group.num_seqs() - for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups) - else: - num_generation_tokens = scheduler_outputs.num_batched_tokens - - # Latency Timings. - time_last_iters = [] - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + num_generation_tokens_from_prefill_groups = 0. + # NOTE: if scheduler_outputs.num_prefill_groups > 0 and + # the len of scheduler_outputs.scheduled_seq_groups is != + # scheduler_outputs.num_prefill_groups, this means that + # chunked prefills have been detected. + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + group_was_prefill = idx < scheduler_outputs.num_prefill_groups seq_group = scheduled_seq_group.seq_group - # Time since last token. - # (n.b. updates seq_group.metrics.last_token_time) - time_last_iters.append(seq_group.get_last_latency(now)) - # Time since arrival for all finished requests. + + # NOTE: a seq_group that completed all of its prefill tokens + # in the last iteration will have seq_group.is_prefill() = False + # with group_was_prefill = True + if group_was_prefill: + # Number of prompt tokens. + num_prompt_tokens_iter += ( + scheduled_seq_group.token_chunk_size) + + # If the seq_group just finished the prefill state + # get TTFT. + if not seq_group.is_prefill(): + latency = seq_group.get_last_latency(now) + time_to_first_tokens_iter.append(latency) + + # One generation token per finished prefill. + num_generation_tokens_from_prefill_groups += ( + seq_group.num_seqs()) + else: + # TPOTs. + latency = seq_group.get_last_latency(now) + time_per_output_tokens_iter.append(latency) + + # Because of chunked prefill, we can have a single sequence + # group that does multiple prompt_runs. To prevent logging + # the same metadata more than once per request, we standardize + # on logging request level information for finished requests, + # which can only happen once. if seq_group.is_finished(): + # Latency timings time_e2e_requests.append(now - seq_group.metrics.arrival_time) - time_to_first_tokens = time_last_iters if prompt_run else [] - time_per_output_tokens = [] if prompt_run else time_last_iters + # Metadata + num_prompt_tokens_requests.append( + len(seq_group.prompt_token_ids)) + num_generation_tokens_requests.extend([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ]) + if seq_group.sampling_params is not None: + best_of_requests.append( + seq_group.sampling_params.best_of) + n_requests.append(seq_group.sampling_params.n) + finished_reason_requests.extend([ + SequenceStatus.get_finished_reason(seq.status) + for seq in seq_group.get_finished_seqs() + ]) + + # Number of generation tokens. + # num_batched_tokens equals the number of prompt_tokens plus the + # number of decode_tokens in a single iteration. So, + # num_generation_tokens = num_batched_tokens - num_prompt_tokens + # + num_generation_tokens_from_prefill_groups (since we generate + # one token on prefills on iters where the prefill finishes). + num_generation_tokens_iter = ( + scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter + + num_generation_tokens_from_prefill_groups) + + # Spec decode, if enabled, emits specialized metrics from the worker in + # sampler output. + if model_output and (model_output[0].spec_decode_worker_metrics + is not None): + spec_decode_metrics = model_output[0].spec_decode_worker_metrics + else: + spec_decode_metrics = None return Stats( now=now, - num_running=num_running, - num_swapped=num_swapped, - num_waiting=num_waiting, - gpu_cache_usage=gpu_cache_usage, - cpu_cache_usage=cpu_cache_usage, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=num_generation_tokens, - time_to_first_tokens=time_to_first_tokens, - time_per_output_tokens=time_per_output_tokens, + # System stats + # Scheduler State + num_running_sys=num_running_sys, + num_swapped_sys=num_swapped_sys, + num_waiting_sys=num_waiting_sys, + # KV Cache Usage in % + gpu_cache_usage_sys=gpu_cache_usage_sys, + cpu_cache_usage_sys=cpu_cache_usage_sys, + + # Iteration stats + num_prompt_tokens_iter=num_prompt_tokens_iter, + num_generation_tokens_iter=num_generation_tokens_iter, + time_to_first_tokens_iter=time_to_first_tokens_iter, + time_per_output_tokens_iter=time_per_output_tokens_iter, + spec_decode_metrics=spec_decode_metrics, + num_preemption_iter=num_preemption_iter, + + # Request stats + # Latency time_e2e_requests=time_e2e_requests, + # Metadata + num_prompt_tokens_requests=num_prompt_tokens_requests, + num_generation_tokens_requests=num_generation_tokens_requests, + best_of_requests=best_of_requests, + n_requests=n_requests, + finished_reason_requests=finished_reason_requests, ) - def _check_stop(self, seq: Sequence, - sampling_params: SamplingParams) -> None: - """Stop the finished sequences.""" - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the minimum number of tokens has been generated yet; - # skip the stop string/token checks if not - if seq.get_output_len() < sampling_params.min_tokens: - return - - if sampling_params.detokenize: - for stop_str in sampling_params.stop: - if seq.output_text.endswith(stop_str): - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return - last_token_id = seq.get_last_token_id() - if last_token_id in sampling_params.stop_token_ids: - stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - last_token_id) - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = last_token_id - return - - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): - seq.status = SequenceStatus.FINISHED_STOPPED - return - - def _finalize_sequence(self, seq: Sequence, - sampling_params: SamplingParams, - stop_string: str) -> None: - if sampling_params.include_stop_str_in_output: - return - - if stop_string and seq.output_text.endswith(stop_string): - # Truncate the output text so that the stop string is - # not included in the output. - seq.output_text = seq.output_text[:-len(stop_string)] - def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 905db52a1912b..ae7ae144bc04f 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,6 +1,8 @@ import time from dataclasses import dataclass -from typing import Dict, List +from typing import TYPE_CHECKING +from typing import Counter as CollectionsCounter +from typing import Dict, List, Optional, Protocol, Union import numpy as np from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, @@ -8,6 +10,9 @@ from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + logger = init_logger(__name__) disable_created_metrics() @@ -18,8 +23,9 @@ # begin-metrics-definitions class Metrics: + labelname_finish_reason = "finished_reason" - def __init__(self, labelnames: List[str]): + def __init__(self, labelnames: List[str], max_model_len: int): # Unregister any existing vLLM collectors for collector in list(REGISTRY._collector_to_names): if hasattr(collector, "_name") and "vllm" in collector._name: @@ -31,18 +37,20 @@ def __init__(self, labelnames: List[str]): documentation='information of cache_config') # System stats + # Scheduler State self.gauge_scheduler_running = Gauge( name="vllm:num_requests_running", documentation="Number of requests currently running on GPU.", labelnames=labelnames) - self.gauge_scheduler_swapped = Gauge( - name="vllm:num_requests_swapped", - documentation="Number of requests swapped to CPU.", - labelnames=labelnames) self.gauge_scheduler_waiting = Gauge( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", labelnames=labelnames) + self.gauge_scheduler_swapped = Gauge( + name="vllm:num_requests_swapped", + documentation="Number of requests swapped to CPU.", + labelnames=labelnames) + # KV Cache Usage in % self.gauge_gpu_cache_usage = Gauge( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", @@ -52,7 +60,11 @@ def __init__(self, labelnames: List[str]): documentation="CPU KV-cache usage. 1 means 100 percent usage.", labelnames=labelnames) - # Raw stats from last model iteration + # Iteration stats + self.counter_num_preemption = Counter( + name="vllm:num_preemptions_total", + documentation="Cumulative number of preemption from the engine.", + labelnames=labelnames) self.counter_prompt_tokens = Counter( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", @@ -77,18 +89,51 @@ def __init__(self, labelnames: List[str]): 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5 ]) - self.histogram_e2e_request_latency = Histogram( + + # Request stats + # Latency + self.histogram_e2e_time_request = Histogram( name="vllm:e2e_request_latency_seconds", documentation="Histogram of end to end request latency in seconds.", labelnames=labelnames, buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) + # Metadata + self.histogram_num_prompt_tokens_request = Histogram( + name="vllm:request_prompt_tokens", + documentation="Number of prefill tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_num_generation_tokens_request = Histogram( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_best_of_request = Histogram( + name="vllm:request_params_best_of", + documentation="Histogram of the best_of request parameter.", + labelnames=labelnames, + buckets=[1, 2, 5, 10, 20], + ) + self.histogram_n_request = Histogram( + name="vllm:request_params_n", + documentation="Histogram of the n request parameter.", + labelnames=labelnames, + buckets=[1, 2, 5, 10, 20], + ) + self.counter_request_success = Counter( + name="vllm:request_success_total", + documentation="Count of successfully processed requests.", + labelnames=labelnames + [Metrics.labelname_finish_reason]) - # Legacy metrics + # Deprecated in favor of vllm:prompt_tokens_total self.gauge_avg_prompt_throughput = Gauge( name="vllm:avg_prompt_throughput_toks_per_s", documentation="Average prefill throughput in tokens/s.", labelnames=labelnames, ) + # Deprecated in favor of vllm:generation_tokens_total self.gauge_avg_generation_throughput = Gauge( name="vllm:avg_generation_throughput_toks_per_s", documentation="Average generation throughput in tokens/s.", @@ -99,32 +144,75 @@ def __init__(self, labelnames: List[str]): # end-metrics-definitions +def build_1_2_5_buckets(max_value: int): + """ + Builds a list of buckets with increasing powers of 10 multiplied by + mantissa values (1, 2, 5) until the value exceeds the specified maximum. + + Example: + >>> build_1_2_5_buckets(100) + [1, 2, 5, 10, 20, 50, 100] + """ + mantissa_lst = [1, 2, 5] + exponent = 0 + buckets = [] + while True: + for m in mantissa_lst: + value = m * 10**exponent + if value <= max_value: + buckets.append(value) + else: + return buckets + exponent += 1 + + @dataclass class Stats: """Created by LLMEngine for use by StatLogger.""" now: float - # System stats. - num_running: int - num_waiting: int - num_swapped: int - gpu_cache_usage: float - cpu_cache_usage: float - - # Raw stats from last model iteration. - num_prompt_tokens: int - num_generation_tokens: int - time_to_first_tokens: List[float] - time_per_output_tokens: List[float] + # System stats (should have _sys suffix) + # Scheduler State + num_running_sys: int + num_waiting_sys: int + num_swapped_sys: int + # KV Cache Usage in % + gpu_cache_usage_sys: float + cpu_cache_usage_sys: float + + # Iteration stats (should have _iter suffix) + num_prompt_tokens_iter: int + num_generation_tokens_iter: int + time_to_first_tokens_iter: List[float] + time_per_output_tokens_iter: List[float] + num_preemption_iter: int + + # Request stats (should have _requests suffix) + # Latency time_e2e_requests: List[float] + # Metadata + num_prompt_tokens_requests: List[int] + num_generation_tokens_requests: List[int] + best_of_requests: List[int] + n_requests: List[int] + finished_reason_requests: List[str] + + spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> Dict[str, str]: + ... class StatLogger: """StatLogger is used LLMEngine to log to Promethus and Stdout.""" - def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: + def __init__(self, local_interval: float, labels: Dict[str, str], + max_model_len: int) -> None: # Metadata for logging locally. - self.last_local_log = time.monotonic() + self.last_local_log = time.time() self.local_interval = local_interval # Tracked stats over current local logging interval. @@ -133,9 +221,10 @@ def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: # Prometheus metrics self.labels = labels - self.metrics = Metrics(labelnames=list(labels.keys())) + self.metrics = Metrics(labelnames=list(labels.keys()), + max_model_len=max_model_len) - def info(self, type: str, obj: object) -> None: + def info(self, type: str, obj: SupportsMetricsInfo) -> None: if type == "cache_config": self.metrics.info_cache_config.info(obj.metrics_info()) @@ -147,34 +236,68 @@ def _local_interval_elapsed(self, now: float) -> bool: return elapsed_time > self.local_interval def _log_prometheus(self, stats: Stats) -> None: - # Set system stat gauges. - self.metrics.gauge_scheduler_running.labels(**self.labels).set( - stats.num_running) - self.metrics.gauge_scheduler_swapped.labels(**self.labels).set( - stats.num_swapped) - self.metrics.gauge_scheduler_waiting.labels(**self.labels).set( - stats.num_waiting) - self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set( - stats.gpu_cache_usage) - self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set( - stats.cpu_cache_usage) - - # Add to token counters. - self.metrics.counter_prompt_tokens.labels(**self.labels).inc( - stats.num_prompt_tokens) - self.metrics.counter_generation_tokens.labels(**self.labels).inc( - stats.num_generation_tokens) - - # Observe request level latencies in histograms. - for ttft in stats.time_to_first_tokens: - self.metrics.histogram_time_to_first_token.labels( - **self.labels).observe(ttft) - for tpot in stats.time_per_output_tokens: - self.metrics.histogram_time_per_output_token.labels( - **self.labels).observe(tpot) - for e2e in stats.time_e2e_requests: - self.metrics.histogram_e2e_request_latency.labels( - **self.labels).observe(e2e) + # System state data + self._log_gauge(self.metrics.gauge_scheduler_running, + stats.num_running_sys) + self._log_gauge(self.metrics.gauge_scheduler_swapped, + stats.num_swapped_sys) + self._log_gauge(self.metrics.gauge_scheduler_waiting, + stats.num_waiting_sys) + self._log_gauge(self.metrics.gauge_gpu_cache_usage, + stats.gpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_cache_usage, + stats.cpu_cache_usage_sys) + + # Iteration level data + self._log_counter(self.metrics.counter_num_preemption, + stats.num_preemption_iter) + self._log_counter(self.metrics.counter_prompt_tokens, + stats.num_prompt_tokens_iter) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) + self._log_histogram(self.metrics.histogram_time_to_first_token, + stats.time_to_first_tokens_iter) + self._log_histogram(self.metrics.histogram_time_per_output_token, + stats.time_per_output_tokens_iter) + + # Request level data + # Latency + self._log_histogram(self.metrics.histogram_e2e_time_request, + stats.time_e2e_requests) + # Metadata + finished_reason_counter = CollectionsCounter( + stats.finished_reason_requests) + self._log_counter_labels(self.metrics.counter_request_success, + finished_reason_counter, + Metrics.labelname_finish_reason) + self._log_histogram(self.metrics.histogram_num_prompt_tokens_request, + stats.num_prompt_tokens_requests) + self._log_histogram( + self.metrics.histogram_num_generation_tokens_request, + stats.num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) + self._log_histogram(self.metrics.histogram_best_of_request, + stats.best_of_requests) + + def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) + + def _log_counter(self, counter: Counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + counter.labels(**self.labels).inc(data) + + def _log_counter_labels(self, counter: Counter, data: CollectionsCounter, + label_key: str) -> None: + # Convenience function for collection counter of labels. + for label, count in data.items(): + counter.labels(**{**self.labels, label_key: label}).inc(count) + + def _log_histogram(self, histogram: Histogram, + data: Union[List[int], List[float]]) -> None: + # Convenience function for logging list to histogram. + for datum in data: + histogram.labels(**self.labels).observe(datum) def _log_prometheus_interval(self, prompt_throughput: float, generation_throughput: float) -> None: @@ -199,8 +322,8 @@ def log(self, stats: Stats) -> None: self._log_prometheus(stats) # Save tracked stats for token counters. - self.num_prompt_tokens.append(stats.num_prompt_tokens) - self.num_generation_tokens.append(stats.num_generation_tokens) + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) # Log locally every local_interval seconds. if self._local_interval_elapsed(stats.now): @@ -216,16 +339,37 @@ def log(self, stats: Stats) -> None: # Log to stdout. logger.info( - f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, " - f"Avg generation throughput: " - f"{generation_throughput:.1f} tokens/s, " - f"Running: {stats.num_running} reqs, " - f"Swapped: {stats.num_swapped} reqs, " - f"Pending: {stats.num_waiting} reqs, " - f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, " - f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%") + "Avg prompt throughput: %.1f tokens/s, " + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Swapped: %d reqs, " + "Pending: %d reqs, GPU KV cache usage: %.1f%%, " + "CPU KV cache usage: %.1f%%.", + prompt_throughput, + generation_throughput, + stats.num_running_sys, + stats.num_swapped_sys, + stats.num_waiting_sys, + stats.gpu_cache_usage_sys * 100, + stats.cpu_cache_usage_sys * 100, + ) # Reset tracked stats for next interval. self.num_prompt_tokens = [] self.num_generation_tokens = [] self.last_local_log = stats.now + + if stats.spec_decode_metrics is not None: + logger.info( + self._format_spec_decode_metrics_str( + stats.spec_decode_metrics)) + + def _format_spec_decode_metrics_str( + self, metrics: "SpecDecodeWorkerMetrics") -> str: + + return ("Speculative metrics: " + f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " + f"System efficiency: {metrics.system_efficiency:.3f}, " + f"Number of speculative tokens: {metrics.num_spec_tokens}, " + f"Number of accepted tokens: {metrics.accepted_tokens}, " + f"Number of draft tokens tokens: {metrics.draft_tokens}, " + f"Number of emitted tokens tokens: {metrics.emitted_tokens}.") diff --git a/vllm/engine/output_processor/__init__.py b/vllm/engine/output_processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py new file mode 100644 index 0000000000000..9ddb6a3648b8c --- /dev/null +++ b/vllm/engine/output_processor/interfaces.py @@ -0,0 +1,76 @@ +from abc import ABC, abstractmethod +from typing import Callable, List + +from transformers import PreTrainedTokenizer + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + + +class SequenceGroupOutputProcessor(ABC): + """Interface for logic that processes new token ids in sequence groups, + managing detokenization, stop checking, and freeing/forking sequences with + the scheduler. + + This is highly coupled with the LLMEngine and should be seen as an extension + of it. The logic is separated to simplify the LLMEngine class and allow + separate implementations for single-step decoding (which supports beam + search sequence forking) and multi-step decoding (which does not support + beam search, but does support speculative decoding). + """ + + @staticmethod + def create_output_processor( + scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, + scheduler: Scheduler, + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + stop_checker: "StopChecker", + ): + """Create an output processor. + + This returns a single-step output processor if num_lookahead_slots is + zero, else returns a multi-step output processor. + """ + if scheduler_config.num_lookahead_slots == 0: + # Importing here to avoid cycle. + from vllm.engine.output_processor.single_step import ( + SingleStepOutputProcessor) + return SingleStepOutputProcessor( + scheduler_config, + detokenizer, + scheduler, + seq_counter, + stop_checker, + ) + else: + # Importing here to avoid cycle. + from vllm.engine.output_processor.multi_step import ( + MultiStepOutputProcessor) + return MultiStepOutputProcessor( + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + stop_checker, + ) + + @abstractmethod + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Process new token ids for the sequence group. Handles logic such as + detokenization, stop checking, and freeing/forking sequences in the + scheduler. + """ + pass + + @abstractmethod + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Update prompt logprobs received from outputs to seq_group.""" + pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py new file mode 100644 index 0000000000000..761e4ddd82714 --- /dev/null +++ b/vllm/engine/output_processor/multi_step.py @@ -0,0 +1,144 @@ +import functools +from typing import Callable, List + +from transformers import PreTrainedTokenizer + +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, + SequenceOutput, SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + +logger = init_logger(__name__) + + +class MultiStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles logic related to + detokenization and stopping conditions. It specializes to "multi-step + decoding", where vLLM's worker may generate multiple tokens per invocation. + This is currently mutually exclusive with advanced sampling techniques like + beam search, which motivates the separation of this logic from the single + step output processor. + + This class is responsible for things such as correctly appending all new + token ids to their sequence, detokenizing new token ids, truncating new + output tokens after an eos token, and correctly handling the case where the + number of new output tokens per sequence differs in a single batch. + """ + + def __init__( + self, + detokenizer: Detokenizer, + scheduler: Scheduler, + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + stop_checker: StopChecker, + ): + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.stop_checker = stop_checker + + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + # TODO(sang): Prompt logprob currently not implemented in multi step + # workers. + self._log_prompt_logprob_unsupported_warning_once() + + @staticmethod + @functools.lru_cache() + def _log_prompt_logprob_unsupported_warning_once(): + logger.warning( + "Prompt logprob is not supported by multi step workers. " + "(e.g., speculative decode uses multi step workers).") + + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Append new tokens in the outputs to sequences in the sequence group. + + This only supports sequence groups of size 1. It supports greater than + one new token per sequence. + + This applies logic like stop condition checking and detokenization, + including freeing finished sequences. It also handles cases where there + are tokens emitted after the EOS token. + """ + seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) + + assert seqs, "expected running sequences" + assert len(seqs) == 1, ( + "Beam search not supported in multi-step decoding.") + seq = seqs[0] + + # Since there's only one sequence per sequence group, we can take the + # first sample. + samples = [outputs[step].samples[0] for step in range(len(outputs))] + + # -1 means the output token is not valid (eg. due to spec decode + # rejecting tokens). + valid_samples = [ + sample for sample in samples if sample.output_token != -1 + ] + assert valid_samples + + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) + + def _process_seq_outputs(self, seq: Sequence, + valid_samples: List[SequenceOutput], + sampling_params: SamplingParams) -> None: + output_token_ids = [sample.output_token for sample in valid_samples] + output_logprobs = [sample.logprobs for sample in valid_samples] + + # Truncate to max_tokens if necessary. + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + + len(output_token_ids)) + if remaining_tokens < 0: + valid_samples = valid_samples[:remaining_tokens] + output_token_ids = output_token_ids[:remaining_tokens] + + # Truncate any tokens after EOS. This is required as spec decode + # generates a fixed number of tokens without evaluating stopping + # conditions within the block. This can cause an eos token to be + # unintentionally ignored. + if not sampling_params.ignore_eos: + eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id + # Avoiding .index calls as exception throwing in the happy path + # is expensive. + for i in range(len(output_token_ids)): + if output_token_ids[i] == eos_token_id: + output_token_ids = output_token_ids[:i + 1] + valid_samples = valid_samples[:i + 1] + break + + # Incrementally append tokens to the sequence, as if we had only one new + # token. + for output_token_id, output_logprob in zip(output_token_ids, + output_logprobs): + seq.append_token_id( + token_id=output_token_id, + logprobs=output_logprob, + ) + + new_char_count = 0 + if sampling_params.detokenize: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + + # TODO(sang): Support lora. + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count=new_char_count, + sampling_params=sampling_params, + ) + if seq.is_finished(): + break + + if seq.is_finished(): + self.scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py new file mode 100644 index 0000000000000..44de1d7ec5607 --- /dev/null +++ b/vllm/engine/output_processor/single_step.py @@ -0,0 +1,288 @@ +from typing import Dict, List, Tuple, Union + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, + SequenceOutput, SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + +logger = init_logger(__name__) + + +class SingleStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles "output processing" logic, + which happens after the model returns generated token ids and before + scheduling of the next batch. Output processing logic includes + detokenization, and determining if a sequence is finished (e.g. via max len + or eos token). + + The SingleStepOutputProcessor is specialized to the case where the model + emits at most a single token per invocation, which precludes configurations + such as speculative decoding or multi-step decoding. This enables beam + search sampling, which requires forking/finishing/freeing sequences in a way + that is currently difficult to schedule multiple steps ahead of time. + """ + + def __init__( + self, + scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, + scheduler: Scheduler, + seq_counter: Counter, + stop_checker: StopChecker, + ): + self.scheduler_config = scheduler_config + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.stop_checker = stop_checker + + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Append all new tokens to sequences in the sequence group. Fork any + surviving beam candidates; free any unsurviving ones. + + Invokes detokenizer to detokenize new tokens, and also marks sequences + as finished if they meet stop conditions. + """ + assert (len(outputs) == 1 + ), f"{type(self)} does not support multiple outputs per step" + return self._process_sequence_group_outputs(sequence_group, outputs[0]) + + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + assert len(outputs) == 1, ("Single step should only has 1 output.") + output = outputs[0] + prompt_logprobs = output.prompt_logprobs + if (prompt_logprobs is not None + and seq_group.sampling_params.detokenize and self.detokenizer): + self.detokenizer.decode_prompt_logprobs_inplace( + seq_group, prompt_logprobs) + if not seq_group.prompt_logprobs: + # The first prompt token's logprob is None because it doesn't + # have tokens that are precedent. + seq_group.prompt_logprobs = [None] + seq_group.prompt_logprobs.extend(prompt_logprobs) + + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput) -> None: + # Process samples + samples = outputs.samples + parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + existing_finished_seqs = seq_group.get_finished_seqs() + parent_child_dict: Dict[int, List[SequenceOutput]] = { + parent_seq.seq_id: [] + for parent_seq in parent_seqs + } + for sample in samples: + parent_child_dict[sample.parent_seq_id].append(sample) + # List of (child, parent) + child_seqs: List[Tuple[Sequence, Sequence]] = [] + + # Process the child samples for each parent sequence + for parent in parent_seqs: + child_samples: List[SequenceOutput] = parent_child_dict[ + parent.seq_id] + if len(child_samples) == 0: + # This parent sequence has no children samples. Remove + # the parent sequence from the sequence group since it will + # not be used in the future iterations. + parent.status = SequenceStatus.FINISHED_ABORTED + seq_group.remove(parent.seq_id) + self.scheduler.free_seq(parent) + continue + # Fork the parent sequence if there are multiple child samples. + for child_sample in child_samples[:-1]: + new_child_seq_id: int = next(self.seq_counter) + child = parent.fork(new_child_seq_id) + child.append_token_id(child_sample.output_token, + child_sample.logprobs) + child_seqs.append((child, parent)) + # Continue the parent sequence for the last child sample. + # We reuse the parent sequence here to reduce redundant memory + # copies, especially when using non-beam search sampling methods. + last_child_sample = child_samples[-1] + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs) + child_seqs.append((parent, parent)) + + for seq, _ in child_seqs: + if seq_group.sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, seq_group.sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + seq_group.sampling_params, + lora_req=seq_group.lora_request, + ) + + # Non-beam search case + if not seq_group.sampling_params.use_beam_search: + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + # NOTE: we need to fork the new sequences before freeing the + # old sequences. + for seq, parent in child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + return + + # Beam search case + # Select the child sequences to keep in the sequence group. + selected_child_seqs = [] + unselected_child_seqs = [] + beam_width = seq_group.sampling_params.best_of + length_penalty = seq_group.sampling_params.length_penalty + + # Select the newly finished sequences with the highest scores + # to replace existing finished sequences. + # Tuple of (seq, parent, is_new) + existing_finished_seqs = [(seq, None, False) + for seq in existing_finished_seqs] + new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs + if seq.is_finished()] + all_finished_seqs = existing_finished_seqs + new_finished_seqs + # Sort the finished sequences by their scores. + all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), + reverse=True) + for seq, parent, is_new in all_finished_seqs[:beam_width]: + if is_new: + # A newly generated child sequence finishes and has a high + # score, so we will add it into the sequence group. + selected_child_seqs.append((seq, parent)) + for seq, parent, is_new in all_finished_seqs[beam_width:]: + if is_new: + # A newly generated child sequence finishes but has a low + # score, so we will not add it into the sequence group. + # Additionally, if this sequence is a continuation of a + # parent sequence, we will need remove the parent sequence + # from the sequence group. + unselected_child_seqs.append((seq, parent)) + else: + # An existing finished sequence has a low score, so we will + # remove it from the sequence group. + seq_group.remove(seq.seq_id) + + # select the top beam_width sequences from the running + # sequences for the next iteration to continue the beam + # search. + running_child_seqs = [(seq, parent) for seq, parent in child_seqs + if not seq.is_finished()] + # Sort the running sequences by their scores. + running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), + reverse=True) + + # Check if we can stop the beam search. + if len(running_child_seqs) == 0: + # No running sequences, stop the beam search. + stop_beam_search = True + elif len(all_finished_seqs) < beam_width: + # Not enough finished sequences, continue the beam search. + stop_beam_search = False + else: + # Check the early stopping criteria + best_running_seq = running_child_seqs[0][0] + current_worst_seq = all_finished_seqs[beam_width - 1][0] + stop_beam_search = self._check_beam_search_early_stopping( + seq_group.sampling_params.early_stopping, + seq_group.sampling_params, best_running_seq, current_worst_seq) + + if stop_beam_search: + # Stop the beam search and remove all the running sequences from + # the sequence group. + unselected_child_seqs.extend(running_child_seqs) + else: + # Continue the beam search and select the top beam_width sequences + # to continue the beam search. + selected_child_seqs.extend(running_child_seqs[:beam_width]) + # The remaining running sequences will not be used in the next + # iteration. Again, if these sequences are continuations of + # parent sequences, we will need to remove the parent sequences + # from the sequence group. + unselected_child_seqs.extend(running_child_seqs[beam_width:]) + + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in selected_child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + for seq, parent in selected_child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + + # Remove the unselected parent sequences from the sequence group and + # free their memory in block manager. + for seq, parent in unselected_child_seqs: + if seq is parent: + # Remove the parent sequence if it is not selected for next + # iteration + seq_group.remove(seq.seq_id) + self.scheduler.free_seq(seq) + + def _check_beam_search_early_stopping( + self, + early_stopping: Union[bool, str], + sampling_params: SamplingParams, + best_running_seq: Sequence, + current_worst_seq: Sequence, + ) -> bool: + assert sampling_params.use_beam_search + length_penalty = sampling_params.length_penalty + if early_stopping is True: + return True + + current_worst_score = current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=current_worst_seq.eos_token_id) + if early_stopping is False: + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id) + else: + assert early_stopping == "never" + if length_penalty > 0.0: + # If length_penalty > 0.0, beam search will prefer longer + # sequences. The highest attainable score calculation is + # based on the longest possible sequence length in this case. + max_possible_length = max( + best_running_seq.get_prompt_len() + + sampling_params.max_tokens, + self.scheduler_config.max_model_len) + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id, + seq_len=max_possible_length)) + else: + # Otherwise, beam search will prefer shorter sequences. The + # highest attainable score calculation is based on the current + # sequence length. + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id)) + return current_worst_score >= highest_attainable_score diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py new file mode 100644 index 0000000000000..96f0d1142611b --- /dev/null +++ b/vllm/engine/output_processor/stop_checker.py @@ -0,0 +1,119 @@ +from typing import Callable, Optional + +from transformers import PreTrainedTokenizer + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import Sequence, SequenceStatus + + +class StopChecker: + """LLMEngine helper class which separates out the logic involving stop + checking. This checks things such as: whether the eos token was emitted, + whether the max_tokens has been consumed, whether a stop string has been + emitted, or if we have exceeded the max model len. + """ + + def __init__(self, max_model_len: int, + get_tokenizer_for_seq: Callable[[Sequence], + PreTrainedTokenizer]): + # Do not use it directly, but use `self._get_max_model_len`. + self._max_model_len = max_model_len + self.get_tokenizer_for_seq = get_tokenizer_for_seq + + def _get_max_model_len(self, lora_req: Optional[LoRARequest]): + if lora_req and lora_req.long_lora_max_len: + return lora_req.long_lora_max_len + else: + return self._max_model_len + + def maybe_stop_sequence( + self, + seq: Sequence, + new_char_count: int, + sampling_params: SamplingParams, + lora_req: Optional[LoRARequest] = None, + ) -> None: + """Stop the finished sequences. + + new_char_count is the number of chars added to the + sequence's output text for the newly generated token + """ + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + # Remove the last EOS token unless explicitly specified + # This prevents unintended exposure of the EOS token + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if a stop token was encountered. + # This assumes a single token produced per step. + last_token_id = seq.get_last_token_id() + if last_token_id in sampling_params.stop_token_ids: + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + # Remove last token + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id + return + + # Check if any stop strings are matched. + stop_str = self._check_stop_strings(seq, new_char_count, + sampling_params) + if stop_str is not None: + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return + + # Check if the sequence has reached max_model_len. + if seq.get_len() > self._get_max_model_len(lora_req): + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + @staticmethod + def _check_stop_strings(seq: Sequence, new_char_count: int, + sampling_params: SamplingParams) -> Optional[str]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. + + Returns the stop string if matched or else None. + """ + if not new_char_count: + return None + + for stop_str in sampling_params.stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = seq.output_text.find( + stop_str, -new_char_count - stop_string_len) + if stop_index == -1: + continue + + if sampling_params.include_stop_str_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(seq.output_text): + # No truncation required. + return stop_str + + # Truncate the output text to either the beginning + # or end of the stop string. + seq.output_text = seq.output_text[:stop_index] + return stop_str + return None diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py new file mode 100644 index 0000000000000..57cc33d911183 --- /dev/null +++ b/vllm/engine/output_processor/util.py @@ -0,0 +1,21 @@ +from typing import List +from typing import Sequence as GenericSequence +from typing import Union + +from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput + + +def create_output_by_sequence_group( + outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]], + num_seq_groups: int) -> List[List[SequenceGroupOutput]]: + """Helper method which transforms a 2d list organized by + [step][sequence group] into [sequence group][step]. + """ + output_by_sequence_group: List[List[SequenceGroupOutput]] = [ + [] for _ in range(num_seq_groups) + ] + for step in outputs: + for i, sequence_group_output in enumerate(step): + output_by_sequence_group[i].append(sequence_group_output) + + return output_by_sequence_group diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 2a47eae112c12..075de0b4efb2d 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -47,6 +47,7 @@ async def generate(request: Request) -> Response: sampling_params = SamplingParams(**request_dict) request_id = random_uuid() + assert engine is not None results_generator = engine.generate(prompt, sampling_params, request_id) # Streaming case @@ -99,6 +100,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument("--log-level", type=str, default="debug") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() engine_args = AsyncEngineArgs.from_cli_args(args) @@ -109,7 +111,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: uvicorn.run(app, host=args.host, port=args.port, - log_level="debug", + log_level=args.log_level, timeout_keep_alive=TIMEOUT_KEEP_ALIVE, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 5777e8179a1c1..6e971ae73f5d0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,17 +1,24 @@ -from typing import List, Optional, Union +from contextlib import contextmanager +from typing import ClassVar, List, Optional, Sequence, Union, cast, overload -import torch from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, + TextTokensPrompt, TokensPrompt, + parse_and_batch_prompt) +from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter +from vllm.utils import Counter, deprecate_kwargs + +logger = init_logger(__name__) class LLM: @@ -23,15 +30,14 @@ class LLM: this class generates texts from the model, using an intelligent batching mechanism and efficient memory management. - NOTE: This class is intended to be used for offline inference. For online - serving, use the `AsyncLLMEngine` class instead. - NOTE: For the comprehensive list of arguments, see `EngineArgs`. - Args: model: The name or path of a HuggingFace Transformers model. tokenizer: The name or path of a HuggingFace Transformers tokenizer. tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. Expect valid prompt_token_ids and None for prompt + from the input. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. tensor_parallel_size: The number of GPUs to use for distributed @@ -42,10 +48,11 @@ class LLM: However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, - we support "awq", "gptq" and "squeezellm". If None, we first check - the `quantization_config` attribute in the model config file. If - that is None, we assume the model weights are not quantized and use - `dtype` to determine the data type of the weights. + we support "awq", "gptq", "squeezellm", and "fp8" (experimental). + If None, we first check the `quantization_config` attribute in the + model config file. If that is None, we assume the model weights are + not quantized and use `dtype` to determine the data type of + the weights. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. It can be a @@ -65,16 +72,38 @@ class LLM: disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. disable_custom_all_reduce: See ParallelConfig + **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) + + Note: + This class is intended to be used for offline inference. For online + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. """ + DEPRECATE_LEGACY: ClassVar[bool] = False + """A flag to toggle whether to deprecate the legacy generate/encode API.""" + + @classmethod + @contextmanager + def deprecate_legacy_api(cls): + cls.DEPRECATE_LEGACY = True + + yield + + cls.DEPRECATE_LEGACY = False + def __init__( self, model: str, tokenizer: Optional[str] = None, tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, trust_remote_code: bool = False, tensor_parallel_size: int = 1, dtype: str = "auto", @@ -85,8 +114,9 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: int = 4, enforce_eager: bool = False, - max_context_len_to_capture: int = 8192, - disable_custom_all_reduce: bool = True, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: @@ -95,6 +125,7 @@ def __init__( model=model, tokenizer=tokenizer, tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, trust_remote_code=trust_remote_code, tensor_parallel_size=tensor_parallel_size, dtype=dtype, @@ -106,6 +137,7 @@ def __init__( swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) @@ -123,108 +155,415 @@ def set_tokenizer( ) -> None: self.llm_engine.tokenizer.tokenizer = tokenizer + @overload # LEGACY: single (prompt + optional token ids) def generate( self, - prompts: Optional[Union[str, List[str]]] = None, - sampling_params: Optional[SamplingParams] = None, + prompts: str, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) + def generate( + self, + prompts: List[str], + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + def generate( + self, + prompts: Optional[str] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + def generate( + self, + prompts: Optional[List[str]] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + def generate( + self, + prompts: None, + sampling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload + def generate( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + /, # We may enable `inputs` keyword after removing the old API + *, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + ) -> List[RequestOutput]: + ... + + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") + def generate( + self, + prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, List[str]]]] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. - NOTE: This class automatically batches the given prompts, considering + This class automatically batches the given prompts, considering the memory constraint. For the best performance, put all of your prompts into a single list and pass it to this method. Args: - prompts: A list of prompts to generate completions for. + inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If - None, we use the default sampling parameters. - prompt_token_ids: A list of token IDs for the prompts. If None, we - use the tokenizer to convert the prompts to token IDs. + None, we use the default sampling parameters. + When it is a single value, it is applied to every prompt. + When it is a list, the list must have the same length as the + prompts and it is paired one by one with the prompt. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data. Returns: - A list of `RequestOutput` objects containing the generated - completions in the same order as the input prompts. + A list of `RequestOutput` objects containing the + generated completions in the same order as the input prompts. + + Note: + Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the ``inputs`` parameter. """ - if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") - if isinstance(prompts, str): - # Convert a single prompt to a list. - prompts = [prompts] - if (prompts is not None and prompt_token_ids is not None - and len(prompts) != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + if prompt_token_ids is not None or multi_modal_data is not None: + inputs = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + else: + inputs = cast( + Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts) + if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() - if multi_modal_data: - multi_modal_data.data = multi_modal_data.data.to(torch.float16) + self._validate_and_add_requests( + inputs=inputs, + params=sampling_params, + lora_request=lora_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, RequestOutput) - # Add requests to the engine. - num_requests = len(prompts) if prompts is not None else len( - prompt_token_ids) + @overload # LEGACY: single (prompt + optional token ids) + def encode( + self, + prompts: str, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) + def encode( + self, + prompts: List[str], + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + def encode( + self, + prompts: Optional[str] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + def encode( + self, + prompts: Optional[List[str]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + def encode( + self, + prompts: None, + pooling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload + def encode( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + /, # We may enable `inputs` keyword after removing the old API + *, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") + def encode( + self, + prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, List[str]]]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + """Generates the completions for the input prompts. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptStrictInputs` + for more details about the format of each input. + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + + Returns: + A list of `EmbeddingRequestOutput` objects containing the + generated embeddings in the same order as the input prompts. + + Note: + Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the ``inputs`` parameter. + """ + if prompt_token_ids is not None or multi_modal_data is not None: + inputs = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + else: + inputs = cast( + Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts) + + if pooling_params is None: + # Use default pooling params. + pooling_params = PoolingParams() + + self._validate_and_add_requests( + inputs=inputs, + params=pooling_params, + lora_request=lora_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) + + # LEGACY + def _convert_v1_inputs( + self, + prompts: Optional[Union[str, List[str]]], + prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + multi_modal_data: Optional[MultiModalData], + ): + # skip_tokenizer_init is now checked in engine + + if prompts is not None: + prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] + if prompt_token_ids is not None: + prompt_token_ids = [ + p["content"] for p in parse_and_batch_prompt(prompt_token_ids) + ] + + num_requests = None + if prompts is not None: + num_requests = len(prompts) + if prompt_token_ids is not None: + if (num_requests is not None + and num_requests != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + + num_requests = len(prompt_token_ids) + if num_requests is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + + inputs: List[PromptInputs] = [] for i in range(num_requests): - prompt = prompts[i] if prompts is not None else None - token_ids = None if prompt_token_ids is None else prompt_token_ids[ - i] + if prompts is not None: + if prompt_token_ids is not None: + item = TextTokensPrompt( + prompt=prompts[i], + prompt_token_ids=prompt_token_ids[i]) + else: + item = TextPrompt(prompt=prompts[i]) + else: + if prompt_token_ids is not None: + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) + else: + raise AssertionError + + if multi_modal_data is not None: + item["multi_modal_data"] = multi_modal_data + + inputs.append(item) + + return inputs + + def _validate_and_add_requests( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, + Sequence[PoolingParams]], + lora_request: Optional[LoRARequest], + ) -> None: + if isinstance(inputs, (str, dict)): + # Convert a single prompt to a list. + inputs = [inputs] + + num_requests = len(inputs) + + if isinstance(params, list) and len(params) != num_requests: + raise ValueError("The lengths of prompts and params " + "must be the same.") + + # Add requests to the engine. + for i, request_inputs in enumerate(inputs): self._add_request( - prompt, - sampling_params, - token_ids, + request_inputs, + params[i] if isinstance(params, Sequence) else params, lora_request=lora_request, - # Get ith image while maintaining the batch dim. - multi_modal_data=MultiModalData( - type=multi_modal_data.type, - data=multi_modal_data.data[i].unsqueeze(0)) - if multi_modal_data else None, ) - return self._run_engine(use_tqdm) def _add_request( self, - prompt: Optional[str], - sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]], + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + inputs, + params, + lora_request=lora_request) - def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: + def _run_engine( + self, *, use_tqdm: bool + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() - pbar = tqdm(total=num_requests, - desc="Processed prompts", - dynamic_ncols=True) + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=f"Generation Speed: {0:.2f} toks/s", + ) # Run the engine. - outputs: List[RequestOutput] = [] + outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] + total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() for output in step_outputs: if output.finished: outputs.append(output) if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + total_toks += sum( + len(stp.token_ids) for stp in output.outputs) + spd = total_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" pbar.update(1) if use_tqdm: pbar.close() # Sort the outputs by request ID. # This is necessary because some requests may be finished earlier than # its previous requests. - outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs + return sorted(outputs, key=lambda x: int(x.request_id)) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 32282bfd8d12b..97b35262329ee 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1,9 +1,10 @@ import asyncio import importlib import inspect -import os +import re from contextlib import asynccontextmanager from http import HTTPStatus +from typing import Optional, Set import fastapi import uvicorn @@ -12,24 +13,33 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import make_asgi_app +from starlette.routing import Mount import vllm +import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest, ErrorResponse) + ChatCompletionResponse, + CompletionRequest, + EmbeddingRequest, ErrorResponse) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext TIMEOUT_KEEP_ALIVE = 5 # seconds -openai_serving_chat: OpenAIServingChat = None -openai_serving_completion: OpenAIServingCompletion = None +openai_serving_chat: OpenAIServingChat +openai_serving_completion: OpenAIServingCompletion +openai_serving_embedding: OpenAIServingEmbedding + logger = init_logger(__name__) +_running_tasks: Set[asyncio.Task] = set() + @asynccontextmanager async def lifespan(app: fastapi.FastAPI): @@ -40,7 +50,9 @@ async def _force_log(): await engine.do_log_stats() if not engine_args.disable_log_stats: - asyncio.create_task(_force_log()) + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) yield @@ -54,8 +66,10 @@ def parse_args(): # Add prometheus asgi middleware to route /metrics requests -metrics_app = make_asgi_app() -app.mount("/metrics", metrics_app) +route = Mount("/metrics", make_asgi_app()) +# Workaround for 307 Redirect for /metrics +route.path_regex = re.compile('^/metrics(?P.*)$') +app.routes.append(route) @app.exception_handler(RequestValidationError) @@ -95,6 +109,7 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") else: + assert isinstance(generator, ChatCompletionResponse) return JSONResponse(content=generator.model_dump()) @@ -112,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) +@app.post("/v1/embeddings") +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + generator = await openai_serving_embedding.create_embedding( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + else: + return JSONResponse(content=generator.model_dump()) + + if __name__ == "__main__": args = parse_args() @@ -123,11 +149,13 @@ async def create_completion(request: CompletionRequest, raw_request: Request): allow_headers=args.allowed_headers, ) - if token := os.environ.get("VLLM_API_KEY") or args.api_key: + if token := envs.VLLM_API_KEY or args.api_key: @app.middleware("http") async def authentication(request: Request, call_next): root_path = "" if args.root_path is None else args.root_path + if request.method == "OPTIONS": + return await call_next(request) if not request.url.path.startswith(f"{root_path}/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: @@ -146,23 +174,41 @@ async def authentication(request: Request, call_next): raise ValueError(f"Invalid middleware {middleware}. " f"Must be a function or a class.") - logger.info(f"vLLM API server version {vllm.__version__}") - logger.info(f"args: {args}") + logger.info("vLLM API server version %s", vllm.__version__) + logger.info("args: %s", args) if args.served_model_name is not None: - served_model = args.served_model_name + served_model_names = args.served_model_name else: - served_model = args.model + served_model_names = [args.model] + engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - openai_serving_chat = OpenAIServingChat(engine, served_model, + + event_loop: Optional[asyncio.AbstractEventLoop] + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running(): + # If the current is instanced by Ray Serve, + # there is already a running event loop + model_config = event_loop.run_until_complete(engine.get_model_config()) + else: + # When using single vLLM without engine_use_ray + model_config = asyncio.run(engine.get_model_config()) + + openai_serving_chat = OpenAIServingChat(engine, model_config, + served_model_names, args.response_role, args.lora_modules, args.chat_template) openai_serving_completion = OpenAIServingCompletion( - engine, served_model, args.lora_modules) - + engine, model_config, served_model_names, args.lora_modules) + openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, + served_model_names) app.root_path = args.root_path uvicorn.run(app, host=args.host, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index cc71931b97955..4c0cb1e4f3e49 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -8,8 +8,8 @@ import json import ssl -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.openai.serving_engine import LoRA +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.entrypoints.openai.serving_engine import LoRAModulePath class LoRAParserAction(argparse.Action): @@ -18,14 +18,17 @@ def __call__(self, parser, namespace, values, option_string=None): lora_list = [] for item in values: name, path = item.split('=') - lora_list.append(LoRA(name, path)) + lora_list.append(LoRAModulePath(name, path)) setattr(namespace, self.dest, lora_list) def make_arg_parser(): parser = argparse.ArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") - parser.add_argument("--host", type=str, default=None, help="host name") + parser.add_argument("--host", + type=nullable_str, + default=None, + help="host name") parser.add_argument("--port", type=int, default=8000, help="port number") parser.add_argument( "--uvicorn-log-level", @@ -49,45 +52,39 @@ def make_arg_parser(): default=["*"], help="allowed headers") parser.add_argument("--api-key", - type=str, + type=nullable_str, default=None, help="If provided, the server will require this key " "to be presented in the header.") - parser.add_argument("--served-model-name", - type=str, - default=None, - help="The model name used in the API. If not " - "specified, the model name will be the same as " - "the huggingface name.") parser.add_argument( "--lora-modules", - type=str, + type=nullable_str, default=None, nargs='+', action=LoRAParserAction, help="LoRA module configurations in the format name=path. " "Multiple modules can be specified.") parser.add_argument("--chat-template", - type=str, + type=nullable_str, default=None, help="The file path to the chat template, " "or the template in single-line form " "for the specified model") parser.add_argument("--response-role", - type=str, + type=nullable_str, default="assistant", help="The role name to return if " "`request.add_generation_prompt=true`.") parser.add_argument("--ssl-keyfile", - type=str, + type=nullable_str, default=None, help="The file path to the SSL key file") parser.add_argument("--ssl-certfile", - type=str, + type=nullable_str, default=None, help="The file path to the SSL cert file") parser.add_argument("--ssl-ca-certs", - type=str, + type=nullable_str, default=None, help="The CA certificates file") parser.add_argument( @@ -98,12 +95,12 @@ def make_arg_parser(): ) parser.add_argument( "--root-path", - type=str, + type=nullable_str, default=None, help="FastAPI root_path when app is behind a path based routing proxy") parser.add_argument( "--middleware", - type=str, + type=nullable_str, action="append", default=[], help="Additional ASGI middleware to apply to the app. " diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f94d22d279cc4..e380212a4d76b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,16 +1,58 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time -from typing import Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union +import openai.types.chat import torch -from pydantic import BaseModel, Field, conint, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator +# pydantic needs the TypedDict from typing_extensions +from typing_extensions import Annotated, Required, TypedDict +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid -class ErrorResponse(BaseModel): +class CustomChatCompletionContentPartParam(TypedDict, total=False): + __pydantic_config__ = ConfigDict(extra="allow") # type: ignore + + type: Required[str] + """The type of the content part.""" + + +ChatCompletionContentPartParam = Union[ + openai.types.chat.ChatCompletionContentPartParam, + CustomChatCompletionContentPartParam] + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + +ChatCompletionMessageParam = Union[ + openai.types.chat.ChatCompletionMessageParam, + CustomChatCompletionMessageParam] + + +class OpenAIBaseModel(BaseModel): + # OpenAI API does not allow extra fields + model_config = ConfigDict(extra="forbid") + + +class ErrorResponse(OpenAIBaseModel): object: str = "error" message: str type: str @@ -18,7 +60,7 @@ class ErrorResponse(BaseModel): code: int -class ModelPermission(BaseModel): +class ModelPermission(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") object: str = "model_permission" created: int = Field(default_factory=lambda: int(time.time())) @@ -30,10 +72,10 @@ class ModelPermission(BaseModel): allow_fine_tuning: bool = False organization: str = "*" group: Optional[str] = None - is_blocking: str = False + is_blocking: bool = False -class ModelCard(BaseModel): +class ModelCard(OpenAIBaseModel): id: str object: str = "model" created: int = Field(default_factory=lambda: int(time.time())) @@ -43,36 +85,38 @@ class ModelCard(BaseModel): permission: List[ModelPermission] = Field(default_factory=list) -class ModelList(BaseModel): +class ModelList(OpenAIBaseModel): object: str = "list" data: List[ModelCard] = Field(default_factory=list) -class UsageInfo(BaseModel): +class UsageInfo(OpenAIBaseModel): prompt_tokens: int = 0 total_tokens: int = 0 completion_tokens: Optional[int] = 0 -class ResponseFormat(BaseModel): +class ResponseFormat(OpenAIBaseModel): # type must be "json_object" or "text" - type: str = Literal["text", "json_object"] + type: Literal["text", "json_object"] -class ChatCompletionRequest(BaseModel): +class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create - messages: List[Dict[str, str]] + messages: List[ChatCompletionMessageParam] model: str frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None logprobs: Optional[bool] = False - top_logprobs: Optional[int] = None + top_logprobs: Optional[int] = 0 max_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 response_format: Optional[ResponseFormat] = None - seed: Optional[int] = None + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False temperature: Optional[float] = 0.7 @@ -133,12 +177,22 @@ class ChatCompletionRequest(BaseModel): description=( "If specified, the output will follow the context free grammar."), ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) # doc: end-chat-completion-extra-params def to_sampling_params(self) -> SamplingParams: - if self.logprobs and not self.top_logprobs: - raise ValueError("Top logprobs must be set when logprobs is.") + # We now allow logprobs being true without top_logrobs. logits_processors = None if self.logit_bias: @@ -146,6 +200,7 @@ def to_sampling_params(self) -> SamplingParams: def logit_bias_logits_processor( token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: + assert self.logit_bias is not None for token_id, bias in self.logit_bias.items(): # Clamp the bias between -100 and 100 per OpenAI API spec bias = min(100, max(-100, bias)) @@ -195,8 +250,21 @@ def check_guided_decoding_count(cls, data): "('guided_json', 'guided_regex' or 'guided_choice').") return data + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "top_logprobs" in data and data["top_logprobs"] is not None: + if "logprobs" not in data or data["logprobs"] is False: + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + elif not 0 <= data["top_logprobs"] <= 20: + raise ValueError( + "`top_logprobs` must be a value in the interval [0, 20].") + return data -class CompletionRequest(BaseModel): + +class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create model: str @@ -207,9 +275,11 @@ class CompletionRequest(BaseModel): logit_bias: Optional[Dict[str, float]] = None logprobs: Optional[int] = None max_tokens: Optional[int] = 16 - n: Optional[int] = 1 + n: int = 1 presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = None + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False suffix: Optional[str] = None @@ -229,7 +299,7 @@ class CompletionRequest(BaseModel): min_tokens: Optional[int] = 0 skip_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True - truncate_prompt_tokens: Optional[conint(ge=1)] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # doc: end-completion-sampling-params # doc: begin-completion-extra-params @@ -265,6 +335,17 @@ class CompletionRequest(BaseModel): description=( "If specified, the output will follow the context free grammar."), ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) # doc: end-completion-extra-params @@ -277,6 +358,7 @@ def to_sampling_params(self): def logit_bias_logits_processor( token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: + assert self.logit_bias is not None for token_id, bias in self.logit_bias.items(): # Clamp the bias between -100 and 100 per OpenAI API spec bias = min(100, max(-100, bias)) @@ -327,20 +409,47 @@ def check_guided_decoding_count(cls, data): "('guided_json', 'guided_regex' or 'guided_choice').") return data + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "logprobs" in data and data[ + "logprobs"] is not None and not 0 <= data["logprobs"] <= 5: + raise ValueError(("if passed, `logprobs` must be a value", + " in the interval [0, 5].")) + return data + + +class EmbeddingRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: str + input: Union[List[int], List[List[int]], str, List[str]] + encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$') + dimensions: Optional[int] = None + user: Optional[str] = None -class LogProbs(BaseModel): + # doc: begin-embedding-pooling-params + additional_data: Optional[Any] = None + + # doc: end-embedding-pooling-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class CompletionLogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None -class CompletionResponseChoice(BaseModel): +class CompletionResponseChoice(OpenAIBaseModel): index: int text: str - logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = Field( + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( default=None, description=( "The stop string or token id that caused the completion " @@ -349,7 +458,7 @@ class CompletionResponseChoice(BaseModel): ) -class CompletionResponse(BaseModel): +class CompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) @@ -358,12 +467,12 @@ class CompletionResponse(BaseModel): usage: UsageInfo -class CompletionResponseStreamChoice(BaseModel): +class CompletionResponseStreamChoice(OpenAIBaseModel): index: int text: str - logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = Field( + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( default=None, description=( "The stop string or token id that caused the completion " @@ -372,7 +481,7 @@ class CompletionResponseStreamChoice(BaseModel): ) -class CompletionStreamResponse(BaseModel): +class CompletionStreamResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) @@ -381,20 +490,49 @@ class CompletionStreamResponse(BaseModel): usage: Optional[UsageInfo] = Field(default=None) -class ChatMessage(BaseModel): +class EmbeddingResponseData(BaseModel): + index: int + object: str = "embedding" + embedding: List[float] + + +class EmbeddingResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[EmbeddingResponseData] + usage: UsageInfo + + +class ChatMessage(OpenAIBaseModel): role: str content: str -class ChatCompletionResponseChoice(BaseModel): +class ChatCompletionLogProb(OpenAIBaseModel): + token: str + logprob: float = -9999.0 + bytes: Optional[List[int]] = None + + +class ChatCompletionLogProbsContent(ChatCompletionLogProb): + top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: Optional[List[ChatCompletionLogProbsContent]] = None + + +class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage - logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = None + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None + stop_reason: Optional[Union[int, str]] = None -class ChatCompletionResponse(BaseModel): +class ChatCompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") object: str = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) @@ -403,23 +541,64 @@ class ChatCompletionResponse(BaseModel): usage: UsageInfo -class DeltaMessage(BaseModel): +class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None -class ChatCompletionResponseStreamChoice(BaseModel): +class ChatCompletionResponseStreamChoice(OpenAIBaseModel): index: int delta: DeltaMessage - logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = None + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None + stop_reason: Optional[Union[int, str]] = None -class ChatCompletionStreamResponse(BaseModel): +class ChatCompletionStreamResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") object: str = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) model: str choices: List[ChatCompletionResponseStreamChoice] usage: Optional[UsageInfo] = Field(default=None) + + +class BatchRequestInput(OpenAIBaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameteters of the request. + body: Union[ChatCompletionRequest, ] + + +class BatchRequestOutput(OpenAIBaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[ChatCompletionResponse] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. + error: Optional[Any] diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py new file mode 100644 index 0000000000000..731f4f4a4028a --- /dev/null +++ b/vllm/entrypoints/openai/run_batch.py @@ -0,0 +1,141 @@ +import argparse +import asyncio +import sys +from io import StringIO + +import aiohttp + +import vllm +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import (BatchRequestInput, + BatchRequestOutput, + ChatCompletionResponse) +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="vLLM OpenAI-Compatible batch runner.") + parser.add_argument( + "-i", + "--input-file", + required=True, + type=str, + help= + "The path or url to a single input file. Currently supports local file " + "paths, or the http protocol (http or https). If a URL is specified, " + "the file should be available via HTTP GET.") + parser.add_argument( + "-o", + "--output-file", + required=True, + type=str, + help="The path or url to a single output file. Currently supports " + "local file paths, or web (http or https) urls. If a URL is specified," + " the file should be available via HTTP PUT.") + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=true`.") + + parser = AsyncEngineArgs.add_cli_args(parser) + return parser.parse_args() + + +async def read_file(path_or_url: str) -> str: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.get(path_or_url) as resp: + return await resp.text() + else: + with open(path_or_url, "r") as f: + return f.read() + + +async def write_file(path_or_url: str, data: str) -> None: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.put(path_or_url, data=data.encode("utf-8")): + pass + else: + # We should make this async, but as long as this is always run as a + # standalone program, blocking the event loop won't effect performance + # in this particular case. + with open(path_or_url, "w") as f: + f.write(data) + + +async def run_request(chat_serving: OpenAIServingChat, + request: BatchRequestInput) -> BatchRequestOutput: + chat_request = request.body + chat_response = await chat_serving.create_chat_completion(chat_request) + if isinstance(chat_response, ChatCompletionResponse): + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=chat_response, + error=None, + ) + else: + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=None, + error=chat_response, + ) + return batch_output + + +async def main(args): + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) + + # When using single vLLM without engine_use_ray + model_config = await engine.get_model_config() + + openai_serving_chat = OpenAIServingChat( + engine, + model_config, + served_model_names, + args.response_role, + ) + + # Submit all requests in the file to the engine "concurrently". + response_futures = [] + for request_json in (await read_file(args.input_file)).strip().split("\n"): + request = BatchRequestInput.model_validate_json(request_json) + response_futures.append(run_request(openai_serving_chat, request)) + + responses = await asyncio.gather(*response_futures) + + output_buffer = StringIO() + for response in responses: + print(response.model_dump_json(), file=output_buffer) + + output_buffer.seek(0) + await write_file(args.output_file, output_buffer.read().strip()) + + # Temporary workaround for https://github.com/vllm-project/vllm/issues/4789 + sys.exit(0) + + +if __name__ == "__main__": + args = parse_args() + + logger.info("vLLM API server version %s", vllm.__version__) + logger.info("args: %s", args) + + asyncio.run(main(args)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0980c3d3cb614..cc5b896e0e56c 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,41 +1,131 @@ import codecs import time -from typing import AsyncGenerator, AsyncIterator, List, Optional, Union +from dataclasses import dataclass +from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List, + Optional) +from typing import Sequence as GenericSequence +from typing import TypedDict, Union, cast, final from fastapi import Request +from openai.types.chat import ChatCompletionContentPartTextParam +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( - ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionContentPartParam, ChatCompletionLogProb, + ChatCompletionLogProbs, ChatCompletionLogProbsContent, + ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput +from vllm.sequence import Logprob from vllm.utils import random_uuid logger = init_logger(__name__) +@final # So that it should be compatible with Dict[str, str] +class ConversationMessage(TypedDict): + role: str + content: str + + +@dataclass(frozen=True) +class ChatMessageParseResult: + messages: List[ConversationMessage] + + class OpenAIServingChat(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, - served_model: str, + model_config: ModelConfig, + served_model_names: List[str], response_role: str, - lora_modules: Optional[List[LoRA]] = None, - chat_template=None): + lora_modules: Optional[List[LoRAModulePath]] = None, + chat_template: Optional[str] = None): super().__init__(engine=engine, - served_model=served_model, + model_config=model_config, + served_model_names=served_model_names, lora_modules=lora_modules) + self.response_role = response_role self._load_chat_template(chat_template) + def _load_chat_template(self, chat_template: Optional[str]): + tokenizer = self.tokenizer + + if chat_template is not None: + try: + with open(chat_template, "r") as f: + tokenizer.chat_template = f.read() + except OSError as e: + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + tokenizer.chat_template = codecs.decode( + chat_template, "unicode_escape") + + logger.info("Using supplied chat template:\n%s", + tokenizer.chat_template) + elif tokenizer.chat_template is not None: + logger.info("Using default chat template:\n%s", + tokenizer.chat_template) + else: + logger.warning( + "No chat template provided. Chat API will not work.") + + def _parse_chat_message_content_parts( + self, + role: str, + parts: Iterable[ChatCompletionContentPartParam], + ) -> ChatMessageParseResult: + texts: List[str] = [] + + for _, part in enumerate(parts): + part_type = part["type"] + if part_type == "text": + text = cast(ChatCompletionContentPartTextParam, part)["text"] + + texts.append(text) + else: + raise NotImplementedError(f"Unknown part type: {part_type}") + + messages = [ConversationMessage(role=role, content="\n".join(texts))] + + return ChatMessageParseResult(messages=messages) + + def _parse_chat_message_content( + self, + message: ChatCompletionMessageParam, + ) -> ChatMessageParseResult: + role = message["role"] + content = message.get("content") + + if content is None: + return ChatMessageParseResult(messages=[]) + if isinstance(content, str): + messages = [ConversationMessage(role=role, content=content)] + return ChatMessageParseResult(messages=messages) + + return self._parse_chat_message_content_parts(role, content) + async def create_chat_completion( - self, request: ChatCompletionRequest, raw_request: Request + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None ) -> Union[ErrorResponse, AsyncGenerator[str, None], ChatCompletionResponse]: """Completion API similar to OpenAI's API. @@ -52,24 +142,36 @@ async def create_chat_completion( return error_check_ret try: + conversation: List[ConversationMessage] = [] + + for msg in request.messages: + parsed_msg = self._parse_chat_message_content(msg) + + conversation.extend(parsed_msg.messages) + prompt = self.tokenizer.apply_chat_template( - conversation=request.messages, + conversation=conversation, tokenize=False, - add_generation_prompt=request.add_generation_prompt) + add_generation_prompt=request.add_generation_prompt, + ) except Exception as e: - logger.error( - f"Error in applying chat template from request: {str(e)}") + logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) request_id = f"cmpl-{random_uuid()}" try: - token_ids = self._validate_prompt_and_tokenize(request, - prompt=prompt) + # Tokenize/detokenize depending on prompt format (string/token list) + prompt_ids, prompt_text = self._validate_prompt_and_tokenize( + request, prompt=prompt, add_special_tokens=False) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + decoding_config = await self.engine.get_decoding_config() + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( await get_guided_decoding_logits_processor( - request, await self.engine.get_tokenizer())) + guided_decoding_backend, request, await + self.engine.get_tokenizer())) if guided_decode_logits_processor: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] @@ -78,17 +180,24 @@ async def create_chat_completion( except ValueError as e: return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt, sampling_params, - request_id, token_ids, - lora_request) + result_generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + request_id, + lora_request, + ) # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, result_generator, request_id) + request, result_generator, request_id, conversation) else: try: return await self.chat_completion_full_generator( - request, raw_request, result_generator, request_id) + request, raw_request, result_generator, request_id, + conversation) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -101,21 +210,21 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: async def chat_completion_stream_generator( self, request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], request_id: str - ) -> Union[ErrorResponse, AsyncGenerator[str, None]]: - - model_name = request.model + result_generator: AsyncIterator[RequestOutput], request_id: str, + conversation: List[ConversationMessage] + ) -> AsyncGenerator[str, None]: + model_name = self.served_model_names[0] created_time = int(time.time()) chunk_object_type = "chat.completion.chunk" first_iteration = True # Send response for each token for each request.n (index) + assert request.n is not None previous_texts = [""] * request.n previous_num_tokens = [0] * request.n finish_reason_sent = [False] * request.n try: async for res in result_generator: - res: RequestOutput # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST # response (by the try...catch). @@ -142,12 +251,10 @@ async def chat_completion_stream_generator( # last message if request.echo: last_msg_content = "" - if request.messages and isinstance( - request.messages, - list) and request.messages[-1].get( - "content") and request.messages[-1].get( - "role") == role: - last_msg_content = request.messages[-1]["content"] + if conversation and conversation[-1].get( + "content") and conversation[-1].get( + "role") == role: + last_msg_content = conversation[-1]["content"] if last_msg_content: for i in range(request.n): @@ -180,11 +287,10 @@ async def chat_completion_stream_generator( previous_num_tokens[i]:] if output.logprobs else None if request.logprobs: - logprobs = self._create_logprobs( + logprobs = self._create_chat_logprobs( token_ids=delta_token_ids, top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - initial_text_offset=len(previous_texts[i]), + num_output_top_logprobs=request.top_logprobs, ) else: logprobs = None @@ -242,16 +348,17 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, request: ChatCompletionRequest, raw_request: Request, - result_generator: AsyncIterator[RequestOutput], - request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: + self, request: ChatCompletionRequest, raw_request: Optional[Request], + result_generator: AsyncIterator[RequestOutput], request_id: str, + conversation: List[ConversationMessage] + ) -> Union[ErrorResponse, ChatCompletionResponse]: - model_name = request.model + model_name = self.served_model_names[0] created_time = int(time.time()) - final_res: RequestOutput = None + final_res: Optional[RequestOutput] = None async for res in result_generator: - if await raw_request.is_disconnected(): + if raw_request is not None and await raw_request.is_disconnected(): # Abort the request if the client disconnects. await self.engine.abort(request_id) return self.create_error_response("Client disconnected") @@ -266,10 +373,10 @@ async def chat_completion_full_generator( top_logprobs = output.logprobs if request.logprobs: - logprobs = self._create_logprobs( + logprobs = self._create_chat_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, + num_output_top_logprobs=request.top_logprobs, ) else: logprobs = None @@ -279,17 +386,14 @@ async def chat_completion_full_generator( message=ChatMessage(role=role, content=output.text), logprobs=logprobs, finish_reason=output.finish_reason, - stop_reason=output.stop_reason, - ) + stop_reason=output.stop_reason) choices.append(choice_data) if request.echo: last_msg_content = "" - if request.messages and isinstance( - request.messages, list) and request.messages[-1].get( - "content") and request.messages[-1].get( - "role") == role: - last_msg_content = request.messages[-1]["content"] + if conversation and conversation[-1].get( + "content") and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] for choice in choices: full_message = last_msg_content + choice.message.content @@ -313,24 +417,50 @@ async def chat_completion_full_generator( return response - def _load_chat_template(self, chat_template): - if chat_template is not None: - try: - with open(chat_template, "r") as f: - self.tokenizer.chat_template = f.read() - except OSError: - # If opening a file fails, set chat template to be args to - # ensure we decode so our escape are interpreted correctly - self.tokenizer.chat_template = codecs.decode( - chat_template, "unicode_escape") - - logger.info( - f"Using supplied chat template:\n{self.tokenizer.chat_template}" - ) - elif self.tokenizer.chat_template is not None: - logger.info( - f"Using default chat template:\n{self.tokenizer.chat_template}" - ) - else: - logger.warning( - "No chat template provided. Chat API will not work.") + def _get_top_logprobs( + self, logprobs: Dict[int, Logprob], + top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]: + return [ + ChatCompletionLogProb( + token=self._get_decoded_token(p[1], p[0]), + logprob=max(p[1].logprob, -9999.0), + bytes=list( + self._get_decoded_token(p[1], + p[0]).encode("utf-8", + errors="replace"))) + for i, p in enumerate(logprobs.items()) + if top_logprobs and i < top_logprobs + ] + + def _create_chat_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + num_output_top_logprobs: Optional[int] = None, + ) -> ChatCompletionLogProbs: + """Create OpenAI-style logprobs.""" + + logprobs_content = [] + + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + logprobs_content.append( + ChatCompletionLogProbsContent( + token=self.tokenizer.decode(token_id), + bytes=list( + self.tokenizer.decode(token_id).encode( + "utf-8", errors="replace")))) + else: + logprobs_content.append( + ChatCompletionLogProbsContent( + token=step_top_logprobs[token_id].decoded_token, + logprob=max(step_top_logprobs[token_id].logprob, + -9999.0), + bytes=list( + step_top_logprobs[token_id].decoded_token.encode( + "utf-8", errors="replace")), + top_logprobs=self._get_top_logprobs( + step_top_logprobs, num_output_top_logprobs))) + + return ChatCompletionLogProbs(content=logprobs_content) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 06e7a9225fefb..2fb122edaf98a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,30 +1,37 @@ -import asyncio import time from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, - Optional, Tuple) + Optional) +from typing import Sequence as GenericSequence +from typing import Tuple from fastapi import Request +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.protocol import (CompletionRequest, +# yapf: disable +from vllm.entrypoints.openai.protocol import (CompletionLogProbs, + CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - LogProbs, UsageInfo) -from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing + UsageInfo) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput -from vllm.utils import random_uuid +from vllm.sequence import Logprob +from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) TypeTokenIDs = List[int] TypeTopLogProbs = List[Optional[Dict[int, float]]] TypeCreateLogProbsFn = Callable[ - [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] + [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs] def parse_prompt_format(prompt) -> Tuple[bool, list]: @@ -50,49 +57,14 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: return prompt_is_tokens, prompts -def merge_async_iterators(*iterators): - """Merge multiple asynchronous iterators into a single iterator. - - This method handle the case where some iterators finish before others. - When it yields, it yields a tuple (i, item) where i is the index of the - iterator that yields the item. - """ - queue = asyncio.Queue() - - finished = [False] * len(iterators) - - async def producer(i, iterator): - try: - async for item in iterator: - await queue.put((i, item)) - except Exception as e: - await queue.put(e) - finished[i] = True - - _tasks = [ - asyncio.create_task(producer(i, iterator)) - for i, iterator in enumerate(iterators) - ] - - async def consumer(): - while not all(finished) or not queue.empty(): - item = await queue.get() - if isinstance(item, Exception): - raise item - yield item - await asyncio.gather(*_tasks) - - return consumer() - - class OpenAIServingCompletion(OpenAIServing): - def __init__(self, - engine: AsyncLLMEngine, - served_model: str, - lora_modules: Optional[List[LoRA]] = None): + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, + served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]]): super().__init__(engine=engine, - served_model=served_model, + model_config=model_config, + served_model_names=served_model_names, lora_modules=lora_modules) async def create_completion(self, request: CompletionRequest, @@ -115,18 +87,22 @@ async def create_completion(self, request: CompletionRequest, return self.create_error_response( "suffix is not currently supported") - model_name = request.model + model_name = self.served_model_names[0] request_id = f"cmpl-{random_uuid()}" created_time = int(time.time()) # Schedule the request and get the result generator. - generators = [] + generators: List[AsyncIterator[RequestOutput]] = [] try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + decoding_config = await self.engine.get_decoding_config() + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend guided_decode_logit_processor = ( await get_guided_decoding_logits_processor( - request, await self.engine.get_tokenizer())) + guided_decoding_backend, request, await + self.engine.get_tokenizer())) if guided_decode_logit_processor is not None: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] @@ -136,24 +112,30 @@ async def create_completion(self, request: CompletionRequest, for i, prompt in enumerate(prompts): if prompt_is_tokens: - input_ids = self._validate_prompt_and_tokenize( + prompt_formats = self._validate_prompt_and_tokenize( request, prompt_ids=prompt, truncate_prompt_tokens=sampling_params. truncate_prompt_tokens) else: - input_ids = self._validate_prompt_and_tokenize( + prompt_formats = self._validate_prompt_and_tokenize( request, prompt=prompt, truncate_prompt_tokens=sampling_params. truncate_prompt_tokens) + prompt_ids, prompt_text = prompt_formats + + generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + f"{request_id}-{i}", + lora_request=lora_request, + ) - generators.append( - self.engine.generate(prompt, - sampling_params, - f"{request_id}-{i}", - prompt_token_ids=input_ids, - lora_request=lora_request)) + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -179,7 +161,7 @@ async def create_completion(self, request: CompletionRequest, num_prompts=len(prompts)) # Non-streaming response - final_res_batch: RequestOutput = [None] * len(prompts) + final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) try: async for i, res in result_generator: if await raw_request.is_disconnected(): @@ -216,6 +198,7 @@ async def completion_stream_generator( model_name: str, num_prompts: int, ) -> AsyncGenerator[str, None]: + assert request.n is not None previous_texts = [""] * request.n * num_prompts previous_num_tokens = [0] * request.n * num_prompts has_echoed = [False] * request.n * num_prompts @@ -233,6 +216,7 @@ async def completion_stream_generator( # TODO(simon): optimize the performance by avoiding full # text O(n^2) sending. + assert request.max_tokens is not None if request.echo and request.max_tokens == 0: # only return the prompt delta_text = res.prompt @@ -257,7 +241,7 @@ async def completion_stream_generator( i]:] if output.logprobs else None if request.logprobs is not None: - logprobs = self._create_logprobs( + logprobs = self._create_completion_logprobs( token_ids=delta_token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs, @@ -310,7 +294,7 @@ def request_output_to_completion_response( created_time: int, model_name: str, ) -> CompletionResponse: - choices = [] + choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 num_generated_tokens = 0 for final_res in final_res_batch: @@ -320,13 +304,15 @@ def request_output_to_completion_response( prompt_text = final_res.prompt for output in final_res.outputs: + assert request.max_tokens is not None if request.echo and request.max_tokens == 0: token_ids = prompt_token_ids top_logprobs = prompt_logprobs output_text = prompt_text elif request.echo and request.max_tokens > 0: token_ids = prompt_token_ids + output.token_ids - top_logprobs = prompt_logprobs + output.logprobs + top_logprobs = (prompt_logprobs + output.logprobs + if request.logprobs else None) output_text = prompt_text + output.text else: token_ids = output.token_ids @@ -334,7 +320,10 @@ def request_output_to_completion_response( output_text = output.text if request.logprobs is not None: - logprobs = self._create_logprobs( + assert top_logprobs is not None, ( + "top_logprobs must be provided when logprobs " + "is requested") + logprobs = self._create_completion_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs, @@ -368,3 +357,59 @@ def request_output_to_completion_response( choices=choices, usage=usage, ) + + def _create_completion_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + num_output_top_logprobs: int, + initial_text_offset: int = 0, + ) -> CompletionLogProbs: + """Create logprobs for OpenAI Completion API.""" + out_text_offset: List[int] = [] + out_token_logprobs: List[Optional[float]] = [] + out_tokens: List[str] = [] + out_top_logprobs: List[Optional[Dict[str, float]]] = [] + + last_token_len = 0 + + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = self.tokenizer.decode(token_id) + out_tokens.append(token) + out_token_logprobs.append(None) + out_top_logprobs.append(None) + else: + token = self._get_decoded_token(step_top_logprobs[token_id], + token_id) + token_logprob = max(step_top_logprobs[token_id].logprob, + -9999.0) + out_tokens.append(token) + out_token_logprobs.append(token_logprob) + + # makes sure to add the top num_output_top_logprobs + 1 + # logprobs, as defined in the openai API + # (cf. https://github.com/openai/openai-openapi/blob/ + # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153) + out_top_logprobs.append({ + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + self._get_decoded_token(top_lp[1], top_lp[0]): + max(top_lp[1].logprob, -9999.0) + for i, top_lp in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs >= i + }) + + if len(out_text_offset) == 0: + out_text_offset.append(initial_text_offset) + else: + out_text_offset.append(out_text_offset[-1] + last_token_len) + last_token_len = len(token) + + return CompletionLogProbs( + text_offset=out_text_offset, + token_logprobs=out_token_logprobs, + tokens=out_tokens, + top_logprobs=out_top_logprobs, + ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py new file mode 100644 index 0000000000000..5a3448de3d7a4 --- /dev/null +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -0,0 +1,144 @@ +import time +from typing import AsyncIterator, List, Optional, Tuple + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import (EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, UsageInfo) +from vllm.entrypoints.openai.serving_completion import parse_prompt_format +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.logger import init_logger +from vllm.outputs import EmbeddingRequestOutput +from vllm.utils import merge_async_iterators, random_uuid + +logger = init_logger(__name__) + +TypeTokenIDs = List[int] + + +def request_output_to_embedding_response( + final_res_batch: List[EmbeddingRequestOutput], + request_id: str, + created_time: int, + model_name: str, +) -> EmbeddingResponse: + data = [] + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + assert final_res is not None + prompt_token_ids = final_res.prompt_token_ids + + embedding_data = EmbeddingResponseData( + index=idx, embedding=final_res.outputs.embedding) + data.append(embedding_data) + + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return EmbeddingResponse( + id=request_id, + created=created_time, + model=model_name, + data=data, + usage=usage, + ) + + +class OpenAIServingEmbedding(OpenAIServing): + + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, + served_model_names: List[str]): + super().__init__(engine=engine, + model_config=model_config, + served_model_names=served_model_names, + lora_modules=None) + self._check_embedding_mode(model_config.embedding_mode) + + async def create_embedding(self, request: EmbeddingRequest, + raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # Return error for unsupported features. + if request.encoding_format == "base64": + return self.create_error_response( + "base64 encoding is not currently supported") + if request.dimensions is not None: + return self.create_error_response( + "dimensions is currently not supported") + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + created_time = int(time.monotonic()) + + # Schedule the request and get the result generator. + generators = [] + try: + prompt_is_tokens, prompts = parse_prompt_format(request.input) + pooling_params = request.to_pooling_params() + + for i, prompt in enumerate(prompts): + if prompt_is_tokens: + prompt_formats = self._validate_prompt_and_tokenize( + request, prompt_ids=prompt) + else: + prompt_formats = self._validate_prompt_and_tokenize( + request, prompt=prompt) + + prompt_ids, prompt_text = prompt_formats + + generator = self.engine.encode( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + pooling_params, + f"{request_id}-{i}", + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator: AsyncIterator[Tuple[ + int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) + + # Non-streaming response + final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch = [None] * len(prompts) + try: + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(f"{request_id}-{i}") + # TODO: Use a vllm-specific Validation Error + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_embedding_response( + final_res_batch, request_id, created_time, model_name) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response + + def _check_embedding_mode(self, embedding_mode: bool): + if not embedding_mode: + logger.warning( + "embedding_mode is False. Embedding API will not work.") + else: + logger.info("Activating the server engine with embedding enabled.") diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8f69388c0251e..066acdf1c019a 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,15 +1,17 @@ -import asyncio import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic import conint +from pydantic import Field +from typing_extensions import Annotated +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest, ErrorResponse, - LogProbs, ModelCard, ModelList, + CompletionRequest, + EmbeddingRequest, ErrorResponse, + ModelCard, ModelList, ModelPermission) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -20,19 +22,31 @@ @dataclass -class LoRA: +class LoRAModulePath: name: str local_path: str class OpenAIServing: - def __init__(self, - engine: AsyncLLMEngine, - served_model: str, - lora_modules=Optional[List[LoRA]]): + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, + served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]]): + super().__init__() + self.engine = engine - self.served_model = served_model + self.max_model_len = model_config.max_model_len + + # A separate tokenizer to map token IDs to strings. + self.tokenizer = get_tokenizer( + model_config.tokenizer, + tokenizer_mode=model_config.tokenizer_mode, + tokenizer_revision=model_config.tokenizer_revision, + trust_remote_code=model_config.trust_remote_code, + truncation_side="left") + + self.served_model_names = served_model_names + if lora_modules is None: self.lora_requests = [] else: @@ -44,84 +58,23 @@ def __init__(self, ) for i, lora in enumerate(lora_modules, start=1) ] - self.max_model_len = 0 - self.tokenizer = None - - try: - event_loop = asyncio.get_running_loop() - except RuntimeError: - event_loop = None - - if event_loop is not None and event_loop.is_running(): - # If the current is instanced by Ray Serve, - # there is already a running event loop - event_loop.create_task(self._post_init()) - else: - # When using single vLLM without engine_use_ray - asyncio.run(self._post_init()) - - async def _post_init(self): - engine_model_config = await self.engine.get_model_config() - self.max_model_len = engine_model_config.max_model_len - - # A separate tokenizer to map token IDs to strings. - self.tokenizer = get_tokenizer( - engine_model_config.tokenizer, - tokenizer_mode=engine_model_config.tokenizer_mode, - trust_remote_code=engine_model_config.trust_remote_code, - truncation_side="left") - async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ - ModelCard(id=self.served_model, - root=self.served_model, + ModelCard(id=served_model_name, + root=self.served_model_names[0], permission=[ModelPermission()]) + for served_model_name in self.served_model_names ] lora_cards = [ ModelCard(id=lora.lora_name, - root=self.served_model, + root=self.served_model_names[0], permission=[ModelPermission()]) for lora in self.lora_requests ] model_cards.extend(lora_cards) return ModelList(data=model_cards) - def _create_logprobs( - self, - token_ids: List[int], - top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None, - num_output_top_logprobs: Optional[int] = None, - initial_text_offset: int = 0, - ) -> LogProbs: - """Create OpenAI-style logprobs.""" - logprobs = LogProbs() - last_token_len = 0 - if num_output_top_logprobs: - logprobs.top_logprobs = [] - for i, token_id in enumerate(token_ids): - step_top_logprobs = top_logprobs[i] - if step_top_logprobs is not None: - token_logprob = step_top_logprobs[token_id].logprob - else: - token_logprob = None - token = step_top_logprobs[token_id].decoded_token - logprobs.tokens.append(token) - logprobs.token_logprobs.append(token_logprob) - if len(logprobs.text_offset) == 0: - logprobs.text_offset.append(initial_text_offset) - else: - logprobs.text_offset.append(logprobs.text_offset[-1] + - last_token_len) - last_token_len = len(token) - - if num_output_top_logprobs: - logprobs.top_logprobs.append({ - p.decoded_token: p.logprob - for i, p in step_top_logprobs.items() - } if step_top_logprobs else None) - return logprobs - def create_error_response( self, message: str, @@ -144,32 +97,40 @@ def create_streaming_error_response( }) return json_str - async def _check_model(self, request) -> Optional[ErrorResponse]: - if request.model == self.served_model: - return + async def _check_model( + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] + ) -> Optional[ErrorResponse]: + if request.model in self.served_model_names: + return None if request.model in [lora.lora_name for lora in self.lora_requests]: - return + return None return self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) - def _maybe_get_lora(self, request) -> Optional[LoRARequest]: - if request.model == self.served_model: - return + def _maybe_get_lora( + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] + ) -> Optional[LoRARequest]: + if request.model in self.served_model_names: + return None for lora in self.lora_requests: if request.model == lora.lora_name: return lora # if _check_model has been called earlier, this will be unreachable - raise ValueError("The model `{request.model}` does not exist.") + raise ValueError(f"The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( self, - request: Union[ChatCompletionRequest, CompletionRequest], + request: Union[ChatCompletionRequest, CompletionRequest, + EmbeddingRequest], prompt: Optional[str] = None, prompt_ids: Optional[List[int]] = None, - truncate_prompt_tokens: Optional[conint(ge=1)] = None - ) -> List[int]: + truncate_prompt_tokens: Optional[Annotated[int, + Field(ge=1)]] = None, + add_special_tokens: bool = True) -> Tuple[List[int], str]: if not (prompt or prompt_ids): raise ValueError("Either prompt or prompt_ids should be provided.") if (prompt and prompt_ids): @@ -177,19 +138,46 @@ def _validate_prompt_and_tokenize( "Only one of prompt or prompt_ids should be provided.") if prompt_ids is None: - tokenizer_kwargs = {} if truncate_prompt_tokens is None else { - "truncation": True, - "max_length": truncate_prompt_tokens, + # When using OpenAIServingChat for chat completions, the + # special tokens (e.g., BOS) have already been added by the + # chat template. Therefore, we do not need to add them again. + # Set add_special_tokens to False to avoid adding the BOS tokens + # again. + tokenizer_kwargs: Dict[str, Any] = { + "add_special_tokens": add_special_tokens } + if truncate_prompt_tokens is not None: + tokenizer_kwargs.update({ + "truncation": True, + "max_length": truncate_prompt_tokens, + }) input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids elif truncate_prompt_tokens is not None: input_ids = prompt_ids[-truncate_prompt_tokens:] else: input_ids = prompt_ids + input_text = prompt if prompt is not None else self.tokenizer.decode( + prompt_ids) token_num = len(input_ids) + # Note: EmbeddingRequest doesn't have max_tokens + if isinstance(request, EmbeddingRequest): + if token_num > self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the input for embedding " + f"generation. Please reduce the length of the input.", ) + return input_ids, input_text + if request.max_tokens is None: + if token_num >= self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the messages, " + f"Please reduce the length of the messages.", ) request.max_tokens = self.max_model_len - token_num if token_num + request.max_tokens > self.max_model_len: @@ -201,4 +189,9 @@ def _validate_prompt_and_tokenize( f"{request.max_tokens} in the completion). " f"Please reduce the length of the messages or completion.", ) else: - return input_ids + return input_ids, input_text + + def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str: + if logprob.decoded_token is not None: + return logprob.decoded_token + return self.tokenizer.decode(token_id) diff --git a/vllm/envs.py b/vllm/envs.py new file mode 100644 index 0000000000000..35421b9026f1e --- /dev/null +++ b/vllm/envs.py @@ -0,0 +1,235 @@ +import os +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional + +if TYPE_CHECKING: + VLLM_HOST_IP: str = "" + VLLM_PORT: Optional[int] = None + VLLM_USE_MODELSCOPE: bool = False + VLLM_INSTANCE_ID: Optional[str] = None + VLLM_NCCL_SO_PATH: Optional[str] = None + LD_LIBRARY_PATH: Optional[str] = None + VLLM_USE_TRITON_FLASH_ATTN: bool = True + RANK: int = 0 + LOCAL_RANK: int = 0 + CUDA_VISIBLE_DEVICES: Optional[str] = None + VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 + VLLM_API_KEY: Optional[str] = None + S3_ACCESS_KEY_ID: Optional[str] = None + S3_SECRET_ACCESS_KEY: Optional[str] = None + S3_ENDPOINT_URL: Optional[str] = None + VLLM_CONFIG_ROOT: str = "" + VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" + VLLM_NO_USAGE_STATS: bool = False + VLLM_DO_NOT_TRACK: bool = False + VLLM_USAGE_SOURCE: str = "" + VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_LOGGING_LEVEL: str = "INFO" + VLLM_LOGGING_CONFIG_PATH: Optional[str] = None + VLLM_TRACE_FUNCTION: int = 0 + VLLM_ATTENTION_BACKEND: Optional[str] = None + VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_USE_RAY_COMPILED_DAG: bool = False + VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" + VLLM_TARGET_DEVICE: str = "cuda" + MAX_JOBS: Optional[str] = None + NVCC_THREADS: Optional[str] = None + VLLM_BUILD_WITH_NEURON: bool = False + VLLM_USE_PRECOMPILED: bool = False + VLLM_INSTALL_PUNICA_KERNELS: bool = False + CMAKE_BUILD_TYPE: Optional[str] = None + VERBOSE: bool = False + +# The begin-* and end* here are used by the documentation generator +# to extract the used env vars. + +# begin-env-vars-definition + +environment_variables: Dict[str, Callable[[], Any]] = { + + # ================== Installation Time Env Vars ================== + + # Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] + "VLLM_TARGET_DEVICE": + lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), + + # Maximum number of compilation jobs to run in parallel. + # By default this is the number of CPUs + "MAX_JOBS": + lambda: os.getenv("MAX_JOBS", None), + + # Number of threads to use for nvcc + # By default this is 1. + # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. + "NVCC_THREADS": + lambda: os.getenv("NVCC_THREADS", None), + + # If set, vllm will build with Neuron support + "VLLM_BUILD_WITH_NEURON": + lambda: bool(os.environ.get("VLLM_BUILD_WITH_NEURON", False)), + + # If set, vllm will use precompiled binaries (*.so) + "VLLM_USE_PRECOMPILED": + lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")), + + # If set, vllm will install Punica kernels + "VLLM_INSTALL_PUNICA_KERNELS": + lambda: bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))), + + # CMake build type + # If not set, defaults to "Debug" or "RelWithDebInfo" + # Available options: "Debug", "Release", "RelWithDebInfo" + "CMAKE_BUILD_TYPE": + lambda: os.getenv("CMAKE_BUILD_TYPE"), + + # If set, vllm will print verbose logs during installation + "VERBOSE": + lambda: bool(int(os.getenv('VERBOSE', '0'))), + + # Root directory for VLLM configuration files + # Note that this not only affects how vllm finds its configuration files + # during runtime, but also affects how vllm installs its configuration + # files during **installation**. + "VLLM_CONFIG_ROOT": + lambda: os.environ.get("VLLM_CONFIG_ROOT", None) or os.getenv( + "XDG_CONFIG_HOME", None) or os.path.expanduser("~/.config"), + + # ================== Runtime Env Vars ================== + + # used in distributed environment to determine the master address + 'VLLM_HOST_IP': + lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), + + # used in distributed environment to manually set the communication port + # '0' is used to make mypy happy + 'VLLM_PORT': + lambda: int(os.getenv('VLLM_PORT', '0')) + if 'VLLM_PORT' in os.environ else None, + + # If true, will load models from ModelScope instead of Hugging Face Hub. + # note that the value is true or false, not numbers + "VLLM_USE_MODELSCOPE": + lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", + + # Instance id represents an instance of the VLLM. All processes in the same + # instance should have the same instance id. + "VLLM_INSTANCE_ID": + lambda: os.environ.get("VLLM_INSTANCE_ID", None), + + # path to cudatoolkit home directory, under which should be bin, include, + # and lib directories. + "CUDA_HOME": + lambda: os.environ.get("CUDA_HOME", None), + + # Path to the NCCL library file. It is needed because nccl>=2.19 brought + # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 + "VLLM_NCCL_SO_PATH": + lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), + + # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl + # library file in the locations specified by `LD_LIBRARY_PATH` + "LD_LIBRARY_PATH": + lambda: os.environ.get("LD_LIBRARY_PATH", None), + + # flag to control if vllm should use triton flash attention + "VLLM_USE_TRITON_FLASH_ATTN": + lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in + ("true", "1")), + + # rank of the process in the distributed setting, used to determine + # the driver worker + "RANK": + lambda: int(os.environ.get("RANK", "0")), + + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": + lambda: int(os.environ.get("LOCAL_RANK", "0")), + + # used to control the visible devices in the distributed setting + "CUDA_VISIBLE_DEVICES": + lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), + + # timeout for each iteration in the engine + "VLLM_ENGINE_ITERATION_TIMEOUT_S": + lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")), + + # API key for VLLM API server + "VLLM_API_KEY": + lambda: os.environ.get("VLLM_API_KEY", None), + + # S3 access information, used for tensorizer to load model from S3 + "S3_ACCESS_KEY_ID": + lambda: os.environ.get("S3_ACCESS_KEY_ID", None), + "S3_SECRET_ACCESS_KEY": + lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), + "S3_ENDPOINT_URL": + lambda: os.environ.get("S3_ENDPOINT_URL", None), + + # Usage stats collection + "VLLM_USAGE_STATS_SERVER": + lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"), + "VLLM_NO_USAGE_STATS": + lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", + "VLLM_DO_NOT_TRACK": + lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get( + "DO_NOT_TRACK", None) or "0") == "1", + "VLLM_USAGE_SOURCE": + lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), + + # Logging configuration + # If set to 0, vllm will not configure logging + # If set to 1, vllm will configure logging using the default configuration + # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH + "VLLM_CONFIGURE_LOGGING": + lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_LOGGING_CONFIG_PATH": + lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), + + # this is used for configuring the default logging level + "VLLM_LOGGING_LEVEL": + lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO"), + + # Trace function calls + # If set to 1, vllm will trace function calls + # Useful for debugging + "VLLM_TRACE_FUNCTION": + lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), + + # Backend for attention computation + # Available options: + # - "TORCH_SDPA": use torch.nn.MultiheadAttention + # - "FLASH_ATTN": use FlashAttention + # - "XFORMERS": use XFormers + # - "ROCM_FLASH": use ROCmFlashAttention + "VLLM_ATTENTION_BACKEND": + lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), + + # CPU key-value cache space + # default is 4GB + "VLLM_CPU_KVCACHE_SPACE": + lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + + # If the env var is set, it uses the Ray's compiled DAG API + # which optimizes the control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + "VLLM_USE_RAY_COMPILED_DAG": + lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)), + + # Use dedicated multiprocess context for workers. + # Both spawn and fork work + "VLLM_WORKER_MULTIPROC_METHOD": + lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), +} + +# end-env-vars-definition + + +def __getattr__(name): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 2bf97338da0ed..a2212459f034e 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,37 +1,28 @@ -import os -from typing import Dict, List, Optional +from typing import List, Set, Tuple import torch -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -from vllm.executor.executor_base import ExecutorBase +import vllm.envs as envs +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) logger = init_logger(__name__) class CPUExecutor(ExecutorBase): - def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], *args, **kwargs) -> None: - assert device_config.device_type == "cpu" - assert lora_config is None, "cpu backend doesn't support LoRA" - model_config = _verify_and_get_model_config(model_config) - cache_config = _verify_and_get_cache_config(cache_config) - - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config + def _init_executor(self) -> None: + assert self.device_config.device_type == "cpu" + assert self.lora_config is None, "cpu backend doesn't support LoRA" + self.model_config = _verify_and_get_model_config(self.model_config) + self.cache_config = _verify_and_get_cache_config(self.cache_config) + self.scheduler_config = _verify_and_get_scheduler_config( + self.scheduler_config) # Instantiate the worker and load the model to CPU. self._init_worker() @@ -50,17 +41,19 @@ def _init_worker(self): scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, + load_config=self.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, + vision_language_config=self.vision_language_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ @@ -73,20 +66,16 @@ def initialize_cache(self, num_gpu_blocks: int, # NOTE: We log here to avoid multiple logs when number of workers is # greater than one. We could log in the engine, but not all executors # have GPUs. - logger.info(f"# CPU blocks: {num_cpu_blocks}") + # NOTE: `cpu block` for CPU backend is located on CPU memory but is + # referred as `gpu block`. Because we want to reuse the existing block + # management procedure. + logger.info("# CPU blocks: %d", num_gpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - output = self.driver_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -95,7 +84,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: @@ -104,6 +93,21 @@ def check_health(self) -> None: return +class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) + return output + + async def check_health_async(self) -> None: + # CPUExecutor will always be healthy as long as + # it's running. + return + + def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: if config.dtype == torch.float16: logger.warning("float16 is not supported on CPU, casting to bfloat16.") @@ -116,14 +120,22 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: return config +def _verify_and_get_scheduler_config( + config: SchedulerConfig) -> SchedulerConfig: + if config.chunked_prefill_enabled: + logger.warning("Chunked prefill is not supported on CPU, disable it.") + config.chunked_prefill_enabled = False + + return config + + def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: _GB = 1 << 30 if config.enable_prefix_caching: logger.warning("Prefix caching is not supported on CPU, disable it.") config.enable_prefix_caching = False - kv_cache_space_str = os.getenv("VLLM_CPU_KVCACHE_SPACE", "0") - kv_cache_space = int(kv_cache_space_str) + kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE if kv_cache_space >= 0: if kv_cache_space == 0: diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py new file mode 100644 index 0000000000000..f7c608af1ad39 --- /dev/null +++ b/vllm/executor/distributed_gpu_executor.py @@ -0,0 +1,197 @@ +import asyncio +from abc import abstractmethod +from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union + +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import ExecuteModelRequest, SamplerOutput + +logger = init_logger(__name__) + + +class DistributedGPUExecutor(GPUExecutor): + """Abstract superclass of multi-GPU executor implementations.""" + + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + # Updated by implementations that require additional args to be passed + # to the _run_workers execute_model call + self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} + + super().__init__(*args, **kwargs) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers("determine_num_available_blocks", ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_remote_workers_only=True, + **self.extra_execute_model_run_workers_kwargs) + + # Only the driver worker returns the sampling results. + return self._driver_execute_model(execute_model_req) + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> Set[int]: + return self._run_workers("list_loras") + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self._run_workers("save_sharded_state", + path=path, + pattern=pattern, + max_size=max_size) + + @abstractmethod + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + raise NotImplementedError + + @abstractmethod + def _run_workers( + self, + method: str, + *args, + async_run_remote_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than + blocking on the results. + """ + raise NotImplementedError + + @abstractmethod + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + raise NotImplementedError + + +class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + # Start model execution loop running in the parallel workers + self.parallel_worker_tasks = asyncio.create_task( + self._start_worker_execution_loop()) + + # Only the driver worker returns the sampling results. + return await self._driver_execute_model_async(execute_model_req) + + async def stop_remote_worker_execution_loop_async(self) -> None: + if self.parallel_worker_tasks is None: + return + + await self._driver_execute_model_async() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + await parallel_worker_tasks + + @abstractmethod + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Execute the model asynchronously in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + raise NotImplementedError + + @abstractmethod + async def _start_worker_execution_loop(self): + """Run execution loop on all workers. It guarantees all workers run + the loop or None of them is running the loop. Loop can be stopped by + `stop_remote_worker_execution_loop`. + The API is idempotent (guarantee only 1 loop run at any moment).""" + raise NotImplementedError diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c18edd75d7a4d..4d01939c2e38b 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import List, Optional, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput class ExecutorBase(ABC): @@ -16,7 +16,6 @@ class ExecutorBase(ABC): that can execute the model on multiple devices. """ - @abstractmethod def __init__( self, model_config: ModelConfig, @@ -24,14 +23,29 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: - raise NotImplementedError + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.vision_language_config = vision_language_config + self.speculative_config = speculative_config + + self._init_executor() + + @abstractmethod + def _init_executor(self) -> None: + pass @abstractmethod - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. @@ -39,7 +53,7 @@ def determine_num_available_blocks(self) -> tuple[int, int]: ExecutorBase may require modification of the result, e.g. to ensure the selected cache sizes are compatible with all workers. - Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks are blocks that are "active" on the device and can be appended to. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be appended to. @@ -54,14 +68,16 @@ def initialize_cache(self, num_gpu_blocks: int, raise NotImplementedError @abstractmethod - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - """Executes one model step on the given sequences.""" + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + """Executes at least one model step on the given sequences.""" raise NotImplementedError + def stop_remote_worker_execution_loop(self) -> None: + """Releases parallel workers from model loop.""" + return + @abstractmethod def add_lora(self, lora_request: LoRARequest) -> bool: raise NotImplementedError @@ -71,7 +87,7 @@ def remove_lora(self, lora_id: int) -> bool: raise NotImplementedError @abstractmethod - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise NotImplementedError @abstractmethod @@ -80,22 +96,28 @@ def check_health(self) -> None: exception.""" raise NotImplementedError + def shutdown(self) -> None: + """Shutdown the executor.""" + return + + def __del__(self): + self.shutdown() + class ExecutorAsyncBase(ExecutorBase): @abstractmethod async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" raise NotImplementedError - @abstractmethod + async def stop_remote_worker_execution_loop_async(self) -> None: + """Releases parallel workers from model loop.""" + return + async def check_health_async(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" - raise NotImplementedError + self.check_health() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 80ca5cb7367c5..3ad201f4757ec 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,72 +1,74 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Set, Tuple, Union -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) class GPUExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.vision_language_config = vision_language_config - - assert (not speculative_config - ), "Speculative decoding not yet supported for GPU backend" - - # Instantiate the worker and load the model to GPU. - self._init_worker() - - def _init_worker(self): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - + def _init_executor(self) -> None: + """Initialize the worker and load the model. + """ assert self.parallel_config.world_size == 1, ( "GPUExecutor only supports single GPU.") - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - self.driver_worker = Worker( + self.driver_worker = self._create_worker() + self.driver_worker.init_device() + self.driver_worker.load_model() + + def _get_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> Dict[str, Any]: + """Return worker init args for a given rank.""" + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + return dict( model_config=self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, - local_rank=0, - rank=0, + load_config=self.load_config, + local_rank=local_rank, + rank=rank, distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - is_driver_worker=True, + speculative_config=self.speculative_config, + is_driver_worker=rank == 0, ) - self.driver_worker.init_device() - self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def _create_worker(self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None): + + if self.speculative_config is None: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + else: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + + wrapper = WorkerWrapperBase( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + ) + wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, + distributed_init_method)) + return wrapper.worker + + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ @@ -78,22 +80,15 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: # NOTE: This is logged in the executor because there can be >1 worker # with other executors. We could log in the engine level, but work # remains to abstract away the device for non-GPU configurations. - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - output = self.driver_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + def execute_model( + self, execute_model_req: ExecuteModelRequest + ) -> List[Union[SamplerOutput, PoolerOutput]]: + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -104,7 +99,7 @@ def remove_lora(self, lora_id: int) -> bool: assert lora_id > 0, "lora_id must be greater than 0." return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: @@ -117,19 +112,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: - output = await make_async(self.driver_worker.execute_model)( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy) + execute_model_req: ExecuteModelRequest, + ) -> List[Union[SamplerOutput, PoolerOutput]]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) return output - - async def check_health_async(self) -> None: - # GPUExecutor will always be healthy as long as - # it's running. - return diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py new file mode 100644 index 0000000000000..8fa54454907b5 --- /dev/null +++ b/vllm/executor/multiproc_gpu_executor.py @@ -0,0 +1,155 @@ +import asyncio +import os +from functools import partial +from typing import Any, List, Optional + +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) +from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + get_vllm_instance_id, make_async) + +logger = init_logger(__name__) + + +class MultiprocessingGPUExecutor(DistributedGPUExecutor): + """Python multiprocessing-based multi-GPU executor""" + + def _init_executor(self) -> None: + assert ( + not self.speculative_config + ), "Speculative decoding not yet supported for MultiProcGPU backend." + + # Create the parallel GPU workers. + world_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = (",".join( + map(str, range(world_size)))) + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + from torch.cuda import device_count + assert world_size <= device_count(), ( + "please set tensor_parallel_size to less than max local gpu count") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + if world_size == 1: + self.workers = [] + else: + result_handler = ResultHandler() + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + distributed_init_method=distributed_init_method, + )) for rank in range(1, world_size) + ] + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + self.driver_worker = self._create_worker( + distributed_init_method=distributed_init_method) + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_model( + execute_model_req=execute_model_req) + + def _run_workers( + self, + method: str, + *args, + async_run_remote_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than + blocking on the results. + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if async_run_remote_workers_only: + # Just return futures + return worker_outputs + + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*args, **kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if not self.worker_monitor.is_alive(): + raise RuntimeError("Worker processes are not running") + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + for result in parallel_worker_tasks: + result.get() + + +class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor, + DistributedGPUExecutorAsync): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_model = make_async(self.driver_worker.execute_model) + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + return await self.driver_exec_model(execute_model_req) + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method_async("start_worker_execution_loop") + for worker in self.workers + ] + return await asyncio.gather(*coros) diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py new file mode 100644 index 0000000000000..62887533f5c27 --- /dev/null +++ b/vllm/executor/multiproc_worker_utils.py @@ -0,0 +1,263 @@ +import asyncio +import multiprocessing +import os +import sys +import threading +import traceback +import uuid +from dataclasses import dataclass +from multiprocessing import Queue +from multiprocessing.connection import wait +from multiprocessing.process import BaseProcess +from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO, + TypeVar, Union) + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + +T = TypeVar('T') + +_TERMINATE = "TERMINATE" # sentinel + +# ANSI color codes +CYAN = '\033[1;36m' +RESET = '\033[0;0m' + +JOIN_TIMEOUT_S = 2 + +mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD +mp = multiprocessing.get_context(mp_method) + + +@dataclass +class Result(Generic[T]): + """Result of task dispatched to worker""" + + task_id: uuid.UUID + value: Optional[T] = None + exception: Optional[BaseException] = None + + +class ResultFuture(threading.Event, Generic[T]): + """Synchronous future for non-async case""" + + def __init__(self): + super().__init__() + self.result: Optional[Result[T]] = None + + def set_result(self, result: Result[T]): + self.result = result + self.set() + + def get(self) -> T: + self.wait() + assert self.result is not None + if self.result.exception is not None: + raise self.result.exception + return self.result.value # type: ignore[return-value] + + +def _set_future_result(future: Union[ResultFuture, asyncio.Future], + result: Result): + if isinstance(future, ResultFuture): + future.set_result(result) + return + loop = future.get_loop() + if result.exception is not None: + loop.call_soon_threadsafe(future.set_exception, result.exception) + else: + loop.call_soon_threadsafe(future.set_result, result.value) + + +class ResultHandler(threading.Thread): + """Handle results from all workers (in background thread)""" + + def __init__(self) -> None: + super().__init__(daemon=True) + self.result_queue = mp.Queue() + self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} + + def run(self): + for result in iter(self.result_queue.get, _TERMINATE): + future = self.tasks.pop(result.task_id) + _set_future_result(future, result) + # Ensure that all waiters will receive an exception + for task_id, future in self.tasks.items(): + _set_future_result( + future, + Result(task_id=task_id, + exception=ChildProcessError("worker died"))) + + def close(self): + self.result_queue.put(_TERMINATE) + + +class WorkerMonitor(threading.Thread): + """Monitor worker status (in background thread)""" + + def __init__(self, workers: List['ProcessWorkerWrapper'], + result_handler: ResultHandler): + super().__init__(daemon=True) + self.workers = workers + self.result_handler = result_handler + self._close = False + + def run(self) -> None: + # Blocks until any worker exits + dead_sentinels = wait([w.process.sentinel for w in self.workers]) + if not self._close: + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + process = worker.process + if process.sentinel in dead_sentinels: + process.join(JOIN_TIMEOUT_S) + if process.exitcode is not None and process.exitcode != 0: + logger.error("Worker %s pid %s died, exit code: %s", + process.name, process.pid, process.exitcode) + # Cleanup any remaining workers + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + for worker in self.workers: + worker.process.join(JOIN_TIMEOUT_S) + + def close(self): + if self._close: + return + self._close = True + logger.info("Terminating local vLLM worker processes") + for worker in self.workers: + worker.terminate_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + +class ProcessWorkerWrapper: + """Local process wrapper for vllm.worker.Worker, + for handling single-node multi-GPU tensor parallel.""" + + def __init__(self, result_handler: ResultHandler, + worker_factory: Callable[[], Any]) -> None: + self._task_queue = mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + self.process: BaseProcess = mp.Process( # type: ignore[attr-defined] + target=_run_worker_process, + name="VllmWorkerProcess", + kwargs=dict( + worker_factory=worker_factory, + task_queue=self._task_queue, + result_queue=self.result_queue, + ), + daemon=True) + + self.process.start() + + def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], + method: str, args, kwargs): + task_id = uuid.uuid4() + self.tasks[task_id] = future + try: + self._task_queue.put((task_id, method, args, kwargs)) + except BaseException as e: + del self.tasks[task_id] + raise ChildProcessError("worker died") from e + + def execute_method(self, method: str, *args, **kwargs): + future: ResultFuture = ResultFuture() + self._enqueue_task(future, method, args, kwargs) + return future + + async def execute_method_async(self, method: str, *args, **kwargs): + future = asyncio.get_running_loop().create_future() + self._enqueue_task(future, method, args, kwargs) + return await future + + def terminate_worker(self): + try: + self._task_queue.put(_TERMINATE) + except ValueError: + self.process.kill() + self._task_queue.close() + + def kill_worker(self): + self._task_queue.close() + self.process.kill() + + +def _run_worker_process( + worker_factory: Callable[[], Any], + task_queue: Queue, + result_queue: Queue, +) -> None: + """Worker process event loop""" + + # Add process-specific prefix to stdout and stderr + process_name = mp.current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + # Initialize worker + worker = worker_factory() + del worker_factory + + # Accept tasks from the engine in task_queue + # and return task output in result_queue + logger.info("Worker ready; awaiting tasks") + try: + for items in iter(task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + executor = getattr(worker, method) + output = executor(*args, **kwargs) + except BaseException as e: + tb = traceback.format_exc() + logger.error( + "Exception in worker %s while processing method %s: %s, %s", + process_name, method, e, tb) + exception = e + result_queue.put( + Result(task_id=task_id, value=output, exception=exception)) + except KeyboardInterrupt: + pass + except Exception: + logger.exception("Worker failed") + + logger.info("Worker exiting") + + +def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: + """Prepend each output line with process-specific prefix""" + + prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + if file.start_new_line: # type: ignore[attr-defined] + file_write(prefix) + idx = 0 + while (next_idx := s.find('\n', idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True # type: ignore[attr-defined] + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False # type: ignore[attr-defined] + + file.start_new_line = True # type: ignore[attr-defined] + file.write = write_with_prefix # type: ignore[method-assign] diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 57436a85cfa27..e7f0e887921b7 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,35 +1,20 @@ -from typing import Dict, List, Optional +from typing import List, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) -from vllm.executor.executor_base import ExecutorBase +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import make_async logger = init_logger(__name__) class NeuronExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: - self.model_config = model_config - assert lora_config is None, "LoRA is not supported for Neuron backend." - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - assert (not speculative_config + def _init_executor(self) -> None: + assert (self.lora_config is + None), "LoRA is not supported for Neuron backend." + assert (not self.speculative_config ), "Speculative decoding not yet supported for Neuron backend." # Instantiate the worker and load the model to the device. @@ -43,11 +28,12 @@ def _init_worker(self): self.parallel_config, self.scheduler_config, self.device_config, + self.cache_config, ) self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ @@ -59,17 +45,18 @@ def initialize_cache(self, num_gpu_blocks: int, """ self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - assert (blocks_to_swap_in == {} and blocks_to_swap_out == {} - and blocks_to_copy == {}), ( + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + assert (execute_model_req.blocks_to_swap_in == {} + and execute_model_req.blocks_to_swap_out == {} + and execute_model_req.blocks_to_copy == {}), ( "Cache operations are not supported for Neuron backend.") + assert execute_model_req.num_lookahead_slots == 0, ( + "lookahead not supported for Neuron backend.") output = self.driver_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list) + execute_model_req.seq_group_metadata_list) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -78,10 +65,27 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: # NeuronExecutor will always be healthy as long as # it's running. return + + +class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = await make_async( + self.driver_worker.execute_model + )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, ) + return output + + async def check_health_async(self) -> None: + # NeuronExecutor will always be healthy as long as + # it's running. + return diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6c0ccd7e64c90..bed356d1b6e58 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -1,20 +1,18 @@ import asyncio -import copy import os import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from itertools import islice, repeat +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) -from vllm.engine.ray_utils import RayWorkerVllm, ray -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +import vllm.envs as envs +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async, set_cuda_visible_devices) + get_vllm_instance_id, make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -24,36 +22,13 @@ logger = init_logger(__name__) -# If the env var is set, it uses the Ray's compiled DAG API -# which optimizes the control plane overhead. -# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) +USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG -class RayGPUExecutor(ExecutorBase): +class RayGPUExecutor(DistributedGPUExecutor): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.vision_language_config = vision_language_config - assert (not speculative_config - ), "Speculative decoding not yet supported for RayGPU backend." - - assert self.parallel_config.worker_use_ray + def _init_executor(self) -> None: + assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group # Disable Ray usage stats collection. @@ -67,6 +42,23 @@ def __init__( self.forward_dag = None if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() + self.extra_execute_model_run_workers_kwargs[ + "use_ray_compiled_dag"] = True + + def _configure_ray_workers_use_nsight(self, + ray_remote_kwargs) -> Dict[str, Any]: + # If nsight profiling is enabled, we need to set the profiling + # configuration for the ray workers as runtime env. + runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) + runtime_env.update({ + "nsight": { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } + }) + + return ray_remote_kwargs def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): @@ -79,9 +71,13 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. - self.driver_dummy_worker: RayWorkerVllm = None + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerVllm] = [] + self.workers: List[RayWorkerWrapper] = [] + + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) # Create the workers. driver_ip = get_ip() @@ -93,18 +89,35 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", placement_group_capture_child_tasks=True, placement_group_bundle_index=bundle_id, ) + + if self.speculative_config is not None: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + else: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + worker = ray.remote( num_cpus=0, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerVllm).remote(self.model_config.trust_remote_code) + )(RayWorkerWrapper).remote( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) worker_ip = ray.get(worker.get_node_ip.remote()) if worker_ip == driver_ip and self.driver_dummy_worker is None: # If the worker is on the same node as the driver, we use it # as the resource holder for the driver process. self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) else: # Else, added to the list of workers. self.workers.append(worker) @@ -116,198 +129,127 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", "GPU node.") # Get the set of GPU IDs used on each node. - driver_node_id, driver_gpu_ids = ray.get( - self.driver_dummy_worker.get_node_and_gpu_ids.remote()) - worker_node_and_gpu_ids = ray.get( - [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) node_workers = defaultdict(list) node_gpus = defaultdict(list) - node_workers[driver_node_id].append(0) - node_gpus[driver_node_id].extend(driver_gpu_ids) - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, - start=1): + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): node_workers[node_id].append(i) node_gpus[node_id].extend(gpu_ids) for node_id, gpu_ids in node_gpus.items(): node_gpus[node_id] = sorted(gpu_ids) - # Set CUDA_VISIBLE_DEVICES for the driver and workers. - set_cuda_visible_devices(node_gpus[driver_node_id]) - for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): - worker.set_cuda_visible_devices.remote(node_gpus[node_id]) + VLLM_INSTANCE_ID = get_vllm_instance_id() + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [({ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])), + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + }, ) for (node_id, _) in worker_node_and_gpu_ids] + self._run_workers("update_environment_variables", + all_args=all_args_to_update_environment_variables) distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - - model_config = copy.deepcopy(self.model_config) - parallel_config = copy.deepcopy(self.parallel_config) - scheduler_config = copy.deepcopy(self.scheduler_config) - device_config = copy.deepcopy(self.device_config) - lora_config = copy.deepcopy(self.lora_config) - cache_config = copy.deepcopy(self.cache_config) - vision_language_config = copy.deepcopy(self.vision_language_config) - - # Initialize the actual workers with the Worker class. - for rank, (worker, (node_id, _)) in enumerate( - zip(self.workers, worker_node_and_gpu_ids), - start=1, - ): - local_rank = node_workers[node_id].index(rank) - worker.init_worker.remote( - lambda rank=rank, local_rank=local_rank: Worker( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - cache_config=cache_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - lora_config=lora_config, - vision_language_config=vision_language_config, - )) - - # Initialize the driver worker with the Worker class. - driver_rank = 0 - driver_local_rank = node_workers[driver_node_id].index(driver_rank) - self.driver_worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - local_rank=driver_local_rank, - rank=driver_rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=True, - ) + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_device") - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) - - def determine_num_available_blocks(self) -> tuple[int, int]: - """Determine the number of available KV blocks. + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) - This invokes `determine_num_available_blocks` on each worker and takes - the min of the results, guaranteeing that the selected cache sizes are - compatible with all workers. - - Returns: - - tuple[num_gpu_blocks, num_cpu_blocks] - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks", ) - - # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory - # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - - return num_gpu_blocks, num_cpu_blocks + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers. + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. """ - - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) - - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - all_outputs = self._run_workers( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - }, - use_ray_compiled_dag=USE_RAY_COMPILED_DAG) - - # Only the driver worker returns the sampling results. - output = all_outputs[0] - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "add_lora", - lora_request=lora_request, - ) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "remove_lora", - lora_id=lora_id, - ) - - def list_loras(self) -> List[int]: - return self._run_workers("list_loras") + return self.driver_worker.execute_method("execute_model", + execute_model_req) def _run_workers( self, method: str, *args, - driver_args: Optional[List[Any]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, + async_run_remote_workers_only: bool = False, + all_args: Optional[List[Tuple[Any, ...]]] = None, + all_kwargs: Optional[List[Dict[str, Any]]] = None, + use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: - """Runs the given method on all workers.""" + """Runs the given method on all workers. Can be used in the following + ways: + + - async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than blocking + on the results. + - args/kwargs: All workers share the same args/kwargs + - all_args/all_kwargs: args/kwargs for each worker are specified + individually + """ if max_concurrent_workers: raise NotImplementedError( "max_concurrent_workers is not supported yet.") + count = len(self.workers) + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, 1, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, 1, None) + if use_ray_compiled_dag: # Right now, compiled DAG can only accept a single # input. TODO(sang): Fix it. + assert self.forward_dag is not None output_channels = self.forward_dag.execute(1) + ray_worker_outputs = [] else: # Start the ray workers first. ray_worker_outputs = [ - worker.execute_method.remote(method, *args, **kwargs) - for worker in self.workers + worker.execute_method.remote(method, *worker_args, + **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) ] - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs + if async_run_remote_workers_only: + # Just return futures + return ray_worker_outputs - # Start the driver worker after all the ray workers. - driver_worker_output = getattr(self.driver_worker, - method)(*driver_args, **driver_kwargs) + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) + else: + assert self.driver_dummy_worker is not None + driver_worker_output = ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: if use_ray_compiled_dag: @@ -325,6 +267,11 @@ def _run_workers( return [driver_worker_output] + ray_worker_outputs + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + def _compiled_ray_dag(self): import pkg_resources required_version = "2.9" @@ -334,14 +281,15 @@ def _compiled_ray_dag(self): f"required, but found {current_version}") from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" # Right now, compiled DAG requires at least 1 arg. We send # a dummy value for now. It will be fixed soon. with InputNode() as input_data: forward_dag = MultiOutputNode([ - worker.execute_model_compiled_dag_remote.bind(input_data) - for worker in self.workers + worker.execute_model_compiled_dag_remote. + bind( # type: ignore[attr-defined] + input_data) for worker in self.workers ]) return forward_dag.experimental_compile() @@ -363,55 +311,22 @@ def _check_if_any_actor_is_dead(self): f"Dead Workers: {dead_actors}. ") -class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): +class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): - async def _run_workers_async( - self, - method: str, - *args, - driver_args: Optional[List[Any]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - coros = [] - - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - - # Run the driver worker asynchronously. - driver_executor = make_async(getattr(self.driver_worker, method)) - coros.append(driver_executor(*driver_args, **driver_kwargs)) - - # Run the ray workers asynchronously. - for worker in self.workers: - coros.append(worker.execute_method.remote(method, *args, **kwargs)) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_method = make_async(self.driver_worker.execute_method) - all_outputs = await asyncio.gather(*coros) - return all_outputs - - async def execute_model_async( + async def _driver_execute_model_async( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: - all_outputs = await self._run_workers_async( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - }) - - # Only the driver worker returns the sampling results. - output = all_outputs[0] - return output - - async def check_health_async(self) -> None: - """Raises an error if engine is unhealthy.""" - self._check_if_any_actor_is_dead() + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + return await self.driver_exec_method("execute_model", + execute_model_req) + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.workers + ] + return await asyncio.gather(*coros) diff --git a/vllm/engine/ray_utils.py b/vllm/executor/ray_utils.py similarity index 72% rename from vllm/engine/ray_utils.py rename to vllm/executor/ray_utils.py index 70d5c9b1fae05..4704f5f1b1a10 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -3,47 +3,26 @@ from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import get_ip, is_hip, set_cuda_visible_devices +from vllm.utils import get_ip, is_hip +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) try: import ray - class RayWorkerVllm: + class RayWorkerWrapper(WorkerWrapperBase): """Ray wrapper for vllm.worker.Worker, allowing Worker to be lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" - def __init__(self, init_cached_hf_modules=False) -> None: - if init_cached_hf_modules: - from transformers.dynamic_module_utils import init_hf_modules - init_hf_modules() - self.worker = None + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) # Since the compiled DAG runs a main execution # in a different thread that calls cuda.set_device. # The flag indicates is set_device is called on # that thread. self.compiled_dag_cuda_device_set = False - def init_worker(self, worker_init_fn): - self.worker = worker_init_fn() - - def __getattr__(self, name): - return getattr(self.worker, name) - - def execute_method(self, method, *args, **kwargs): - try: - executor = getattr(self, method) - return executor(*args, **kwargs) - except Exception as e: - # exceptions in ray worker may cause deadlock - # see https://github.com/vllm-project/vllm/issues/3455 - # print the error and inform the user to solve the error - msg = (f"Error executing method {method}. " - "This might cause deadlock in distributed execution.") - logger.exception(msg) - raise e - def get_node_ip(self) -> str: return get_ip() @@ -52,9 +31,6 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def set_cuda_visible_devices(self, device_ids) -> None: - set_cuda_visible_devices(device_ids) - def execute_model_compiled_dag_remote(self, ignored): """Used only when compiled DAG is enabled.""" import torch @@ -67,11 +43,11 @@ def execute_model_compiled_dag_remote(self, ignored): return output except ImportError as e: - logger.warning(f"Failed to import Ray with {e!r}. " - "For distributed inference, please install Ray with " - "`pip install ray`.") - ray = None - RayWorkerVllm = None + logger.warning( + "Failed to import Ray with %r. For multi-node inference, " + "please install Ray with `pip install ray`.", e) + ray = None # type: ignore + RayWorkerWrapper = None # type: ignore def initialize_ray_cluster( @@ -91,7 +67,7 @@ def initialize_ray_cluster( """ if ray is None: raise ImportError( - "Ray is not installed. Please install Ray to use distributed " + "Ray is not installed. Please install Ray to use multi-node " "serving.") # Connect to a ray cluster. diff --git a/vllm/executor/torchrun_gpu_executor.py b/vllm/executor/torchrun_gpu_executor.py index 59a82ef61fc38..506c18c11186f 100644 --- a/vllm/executor/torchrun_gpu_executor.py +++ b/vllm/executor/torchrun_gpu_executor.py @@ -1,19 +1,18 @@ -import os -from typing import Dict, List, Optional +from typing import List, Optional, Tuple, Union import torch -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) +import vllm.envs as envs +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) +from vllm.distributed import (broadcast_object_list, + tensor_model_parallel_all_gather) from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_object_list, tensor_model_parallel_all_gather) -from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.utils import make_async logger = init_logger(__name__) @@ -29,43 +28,25 @@ class TorchrunGPUExecutor(GPUExecutor): def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, - device_config: DeviceConfig, + device_config: DeviceConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig]) -> None: - self.local_rank = int(os.getenv("LOCAL_RANK", "0")) - self.is_driver_worker = self.local_rank == 0 + self.local_rank = envs.LOCAL_RANK + self.rank = envs.RANK + self.is_driver_worker = self.rank == 0 super().__init__(model_config, cache_config, parallel_config, - scheduler_config, device_config, lora_config, - vision_language_config, speculative_config) + scheduler_config, device_config, load_config, + lora_config, vision_language_config, + speculative_config) - def _init_worker(self): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - - assert self.parallel_config.world_size > 1, ( - "TorchrunGPUExecutor only supports multiple GPUs.") - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - self.driver_worker = Worker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - self.cache_config, - local_rank=self.local_rank, - rank=self.local_rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=self.is_driver_worker, - ) + def _init_executor(self): + self.driver_worker = self._create_worker(local_rank=self.local_rank, + rank=self.rank) self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: num_gpu_blocks, num_cpu_blocks = ( self.driver_worker.determine_num_available_blocks()) t = torch.tensor( @@ -76,17 +57,10 @@ def determine_num_available_blocks(self) -> tuple[int, int]: output = tensor_model_parallel_all_gather(t) return (torch.min(output[0]).item(), torch.min(output[1]).item()) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - output = self.driver_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + def execute_model( + self, execute_model_req: ExecuteModelRequest + ) -> List[Union[SamplerOutput, PoolerOutput]]: + output = self.driver_worker.execute_model(execute_model_req) if self.is_driver_worker: broadcast_object_list([output], src=0) else: @@ -100,16 +74,16 @@ class TorchrunGPUExecutorAsync(TorchrunGPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: - output = await make_async(self.driver_worker.execute_model)( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy) + execute_model_req: ExecuteModelRequest, + ) -> List[Union[SamplerOutput, PoolerOutput]]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req) + if self.is_driver_worker: + broadcast_object_list([output], src=0) + else: + res = [None] + broadcast_object_list(res, src=0) + output = res[0] return output async def check_health_async(self) -> None: diff --git a/vllm/inputs.py b/vllm/inputs.py new file mode 100644 index 0000000000000..85c9cd84f5ed5 --- /dev/null +++ b/vllm/inputs.py @@ -0,0 +1,130 @@ +from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, + TypedDict, Union, cast, overload) + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from vllm.sequence import MultiModalData + + +class ParsedText(TypedDict): + content: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + content: List[int] + is_tokens: Literal[True] + + +# https://github.com/vllm-project/vllm/pull/4028 +@overload +def parse_and_batch_prompt( + prompt: Union[str, List[str]]) -> Sequence[ParsedText]: + ... + + +@overload +def parse_and_batch_prompt( + prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: + ... + + +def parse_and_batch_prompt( + prompt: Union[str, List[str], List[int], List[List[int]]], +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedText(content=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0], str): + # case 2: array of strings + return [ + ParsedText(content=elem, is_tokens=False) + for elem in cast(List[str], prompt) + ] + if isinstance(prompt[0], int): + # case 3: array of tokens + elem = cast(List[int], prompt) + return [ParsedTokens(content=elem, is_tokens=True)] + if isinstance(prompt[0], list): + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0][0], int): + # case 4: array of token arrays + return [ + ParsedTokens(content=elem, is_tokens=True) + for elem in cast(List[List[int]], prompt) + ] + + raise ValueError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +class TextPrompt(TypedDict): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +class TokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" + + prompt_token_ids: List[int] + """A list of token IDs to pass to the model.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +class TextTokensPrompt(TypedDict): + """It is assumed that :attr:`prompt` is consistent with + :attr:`prompt_token_ids`. This is currently used in + :class:`AsyncLLMEngine` for logging both the text and token IDs.""" + + prompt: str + """The prompt text.""" + + prompt_token_ids: List[int] + """The token IDs of the prompt. If None, we use the + tokenizer to convert the prompts to token IDs.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +PromptStrictInputs = Union[str, TextPrompt, TokensPrompt] +""" +The inputs to the LLM, which can take one of the following forms: + +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) +""" + +PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] +"""Same as :const:`PromptStrictInputs` but additionally accepts +:class:`TextTokensPrompt`.""" + + +class LLMInputs(TypedDict): + prompt_token_ids: List[int] + prompt: NotRequired[Optional[str]] + multi_modal_data: NotRequired[Optional["MultiModalData"]] diff --git a/vllm/logger.py b/vllm/logger.py index e5e46f5cce3fe..3c6bf0803a624 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -1,61 +1,154 @@ -# Adapted from -# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py """Logging configuration for vLLM.""" +import datetime +import json import logging import os import sys +from functools import partial +from logging import Logger +from logging.config import dictConfig +from os import path +from typing import Dict, Optional -VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) +import vllm.envs as envs + +VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING +VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH +VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" +DEFAULT_LOGGING_CONFIG = { + "formatters": { + "vllm": { + "class": "vllm.logging.NewLineFormatter", + "datefmt": _DATE_FORMAT, + "format": _FORMAT, + }, + }, + "handlers": { + "vllm": { + "class": "logging.StreamHandler", + "formatter": "vllm", + "level": VLLM_LOGGING_LEVEL, + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + "vllm": { + "handlers": ["vllm"], + "level": "DEBUG", + "propagate": False, + }, + }, + "version": 1, +} + + +def _configure_vllm_root_logger() -> None: + logging_config: Optional[Dict] = None -class NewLineFormatter(logging.Formatter): - """Adds logging prefix to newlines to align multi-line messages.""" + if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH: + raise RuntimeError( + "VLLM_CONFIGURE_LOGGING evaluated to false, but " + "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH " + "implies VLLM_CONFIGURE_LOGGING. Please enable " + "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.") + + if VLLM_CONFIGURE_LOGGING: + logging_config = DEFAULT_LOGGING_CONFIG - def __init__(self, fmt, datefmt=None): - logging.Formatter.__init__(self, fmt, datefmt) + if VLLM_LOGGING_CONFIG_PATH: + if not path.exists(VLLM_LOGGING_CONFIG_PATH): + raise RuntimeError( + "Could not load logging config. File does not exist: %s", + VLLM_LOGGING_CONFIG_PATH) + with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8", + mode="r") as file: + custom_config = json.loads(file.read()) - def format(self, record): - msg = logging.Formatter.format(self, record) - if record.message != "": - parts = msg.split(record.message) - msg = msg.replace("\n", "\r\n" + parts[0]) - return msg + if not isinstance(custom_config, dict): + raise ValueError("Invalid logging config. Expected Dict, got %s.", + type(custom_config).__name__) + logging_config = custom_config + if logging_config: + dictConfig(logging_config) -_root_logger = logging.getLogger("vllm") -_default_handler = None +def init_logger(name: str) -> Logger: + """The main purpose of this function is to ensure that loggers are + retrieved in such a way that we can be sure the root vllm logger has + already been configured.""" -def _setup_logger(): - _root_logger.setLevel(logging.DEBUG) - global _default_handler - if _default_handler is None: - _default_handler = logging.StreamHandler(sys.stdout) - _default_handler.flush = sys.stdout.flush # type: ignore - _default_handler.setLevel(logging.INFO) - _root_logger.addHandler(_default_handler) - fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) - _default_handler.setFormatter(fmt) - # Setting this will avoid the message - # being propagated to the parent logger. - _root_logger.propagate = False + return logging.getLogger(name) -# The logger is initialized when the module is imported. +# The root logger is initialized when the module is imported. # This is thread-safe as the module is only imported once, # guaranteed by the Python GIL. -if VLLM_CONFIGURE_LOGGING: - _setup_logger() +_configure_vllm_root_logger() +logger = init_logger(__name__) -def init_logger(name: str): - # Use the same settings as above for root logger - logger = logging.getLogger(name) - logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG")) - if VLLM_CONFIGURE_LOGGING: - logger.addHandler(_default_handler) - logger.propagate = False - return logger + +def _trace_calls(log_path, root_dir, frame, event, arg=None): + if event in ['call', 'return']: + # Extract the filename, line number, function name, and the code object + filename = frame.f_code.co_filename + lineno = frame.f_lineno + func_name = frame.f_code.co_name + if not filename.startswith(root_dir): + # only log the functions in the vllm root_dir + return + # Log every function call or return + try: + last_frame = frame.f_back + if last_frame is not None: + last_filename = last_frame.f_code.co_filename + last_lineno = last_frame.f_lineno + last_func_name = last_frame.f_code.co_name + else: + # initial frame + last_filename = "" + last_lineno = 0 + last_func_name = "" + with open(log_path, 'a') as f: + if event == 'call': + f.write(f"{datetime.datetime.now()} Call to" + f" {func_name} in {filename}:{lineno}" + f" from {last_func_name} in {last_filename}:" + f"{last_lineno}\n") + else: + f.write(f"{datetime.datetime.now()} Return from" + f" {func_name} in {filename}:{lineno}" + f" to {last_func_name} in {last_filename}:" + f"{last_lineno}\n") + except NameError: + # modules are deleted during shutdown + pass + return partial(_trace_calls, log_path, root_dir) + + +def enable_trace_function_call(log_file_path: str, + root_dir: Optional[str] = None): + """ + Enable tracing of every function call in code under `root_dir`. + This is useful for debugging hangs or crashes. + `log_file_path` is the path to the log file. + `root_dir` is the root directory of the code to trace. If None, it is the + vllm root directory. + + Note that this call is thread-level, any threads calling this function + will have the trace enabled. Other threads will not be affected. + """ + logger.warning( + "VLLM_TRACE_FUNCTION is enabled. It will record every" + " function executed by Python. This will slow down the code. It " + "is suggested to be used for debugging hang or crashes only.") + logger.info("Trace frame log is saved to %s", log_file_path) + if root_dir is None: + # by default, this is the vllm root directory + root_dir = os.path.dirname(os.path.dirname(__file__)) + sys.settrace(partial(_trace_calls, log_file_path, root_dir)) diff --git a/vllm/logging/__init__.py b/vllm/logging/__init__.py new file mode 100644 index 0000000000000..b9aec380776f3 --- /dev/null +++ b/vllm/logging/__init__.py @@ -0,0 +1,5 @@ +from vllm.logging.formatter import NewLineFormatter + +__all__ = [ + "NewLineFormatter", +] diff --git a/vllm/logging/formatter.py b/vllm/logging/formatter.py new file mode 100644 index 0000000000000..b24b4e11d1fcb --- /dev/null +++ b/vllm/logging/formatter.py @@ -0,0 +1,15 @@ +import logging + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None, style="%"): + logging.Formatter.__init__(self, fmt, datefmt, style) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py new file mode 100644 index 0000000000000..ffdc32b7339af --- /dev/null +++ b/vllm/lora/fully_sharded_layers.py @@ -0,0 +1,268 @@ +# pylint: disable=unused-argument +from typing import TYPE_CHECKING, List, Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + RowParallelLinearWithLoRA) +from vllm.lora.punica import bgmv, dispatch_bgmv_low_level + +if TYPE_CHECKING: + pass + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs['lora_config'].fully_sharded_loras) + + return dec + + +# these layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked.shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device) + + bgmv(buffer, x, self.lora_a_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + buffer = tensor_model_parallel_all_gather(buffer) + bgmv(output, buffer, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + # now have column partitioned output + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +def _mcp_apply(x, bias, layer): + """ + MergedColumnParallelLinearWithShardedLoRA and + QKVParallelLinearWithShardedLora share the same + LoRa weight application method. + + The main difference is the step by shard_size for lora_b which can + vary for QKVParallelLinearWithShardedLora but is constant for + MergedColumnParallelLinearWithShardedLoRA. + """ + # expecting 2 for column parallel and 3 for qkv + n = len(layer.lora_a_stacked) + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device) + for idx in range(n): + bgmv(buffers[idx], x, layer.lora_a_stacked[idx], + layer.indices[:layer.indices_len[0]], 0, 1.0) + + buffers = tensor_model_parallel_all_gather(buffers) + left_offset = 0 + for idx in range(n): + shard_size = layer.lora_b_stacked[idx].shape[2] + dispatch_bgmv_low_level(output, buffers[idx], + layer.lora_b_stacked[idx], + layer.indices[:layer.indices_len[0]], 0, 1.0, + left_offset, shard_size) + left_offset += shard_size + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_a[0] is None or lora_a[1] is None: + return lora_a + output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[0][:, + output_start_idx:output_start_idx + output_shard_size], + lora_a[1][:, output_start_idx:output_start_idx + output_shard_size] + ] + return lora_a + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): + """ + Differs from QKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None: + return lora_a + shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]], + lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]], + lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]] + ] + return lora_a + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked.shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def apply(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device) + bgmv(buffer, x, self.lora_a_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + shard_size = self.lora_b_stacked.shape[2] + start_idx = self.tp_rank * shard_size + dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0, + start_idx, shard_size) + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 0505014753951..24b74476c3b85 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,8 +1,7 @@ # pylint: disable=unused-argument -import inspect import math from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -10,26 +9,59 @@ from transformers import PretrainedConfig from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_gather) +from vllm.distributed.utils import divide from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding, RotaryEmbedding) from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, - tensor_model_parallel_gather) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import ( - split_tensor_along_last_dim) + VocabParallelEmbedding) if TYPE_CHECKING: pass +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # GPTQ/AWQ/SqueezeLLM + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # marlin + elif hasattr(base_layer, "B"): + return base_layer.B.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True + condition = (not kwargs['lora_config'].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + def _apply_lora( x: torch.Tensor, lora_a_stacked: torch.Tensor, @@ -115,6 +147,18 @@ def __post_init__(self): class BaseLayerWithLoRA(nn.Module): + def slice_lora_a( + self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + def create_lora_weights( self, max_loras: int, @@ -143,6 +187,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): """Sets the mapping indices.""" @@ -161,6 +206,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() self.base_layer = base_layer + self.embeddings_slice: Optional[Tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] def create_lora_weights( self, @@ -218,9 +265,10 @@ def create_lora_weights( self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], self.lora_a_stacked.shape[2], ) - self.indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None - self.embeddings_indices = None + # Lazily initialized. + self.indices: torch.Tensor + self.indices_len: List[int] + self.embeddings_indices: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -252,6 +300,7 @@ def set_lora( self.embeddings_tensors.shape[1], self.embeddings_tensors.shape[2] )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) def set_mapping( @@ -260,6 +309,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = base_indices @@ -268,12 +318,13 @@ def set_mapping( def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + embedding_len = self.indices_len[3] + indices = self.embeddings_indices[1][:embedding_len].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, self.lora_a_stacked_2d, ) - indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + indices = self.embeddings_indices[0][:embedding_len].view_as(x) full_output = self.base_layer.forward( x.add_(indices * added_tokens_mask)) @@ -297,42 +348,67 @@ def can_replace_layer(cls, source_layer: nn.Module, class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + + LoRA B is sliced for tensor parallelism. + """ def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() self.base_layer = base_layer self.tp_size = get_tensor_model_parallel_world_size() + self.input_size = self.base_layer.input_size + self.output_size = self.base_layer.output_size_per_partition + self.device = _get_lora_device(self.base_layer) def create_lora_weights( self, max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config + self.tp_size = get_tensor_model_parallel_world_size() + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) self.lora_a_stacked = torch.zeros( max_loras, 1, - lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + lora_a_output_size_per_partition, + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.lora_b_stacked = torch.zeros( max_loras, 1, - self.base_layer.weight.shape[0], + self.output_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) - - self.indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None self.output_dim = self.lora_b_stacked.shape[2] + # lazily initialized. + self.indices: torch.Tensor + self.indices_len: List[int] + def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + def set_lora( self, index: int, @@ -341,12 +417,11 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], ): self.reset_lora(index) + if self.tp_size > 1: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_dim - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( lora_a.T, non_blocking=True) @@ -360,15 +435,15 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = base_indices self.indices_len = indices_len - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) _apply_lora( x, self.lora_a_stacked, @@ -392,7 +467,7 @@ def forward(self, input_): if not self.base_layer.skip_bias_add else None) # Matrix multiply. - output_parallel = self.apply_weights(input_, bias) + output_parallel = self.apply(input_, bias) if self.base_layer.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -402,11 +477,8 @@ def forward(self, input_): if self.base_layer.skip_bias_add else None) return output, output_bias - @property - def linear_weights(self): - return self.base_layer.linear_weights - @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -432,6 +504,7 @@ def create_lora_weights( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config n_slices = 2 if not (len(self.base_layer.output_sizes) == n_slices and self.base_layer.output_sizes[0] @@ -440,28 +513,34 @@ def create_lora_weights( "LoRAColumnParallelLinear2Slice requires 2 slices with " "the same size.") self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) self.lora_a_stacked = tuple( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + lora_a_output_size_per_partition, + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) for _ in range(n_slices)) self.lora_b_stacked = tuple( torch.zeros( max_loras, 1, - self.base_layer.weight.shape[0] // 2, + self.output_size // 2, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) for _ in range(n_slices)) - self.indices: Optional[torch.Tensor] = None self.output_dim = self.lora_b_stacked[0].shape[2] + # Lazily initialized. + self.indices: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[0][index] = 0 @@ -469,6 +548,24 @@ def reset_lora(self, index: int): self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[1][index] = 0 + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_b[0] is None or lora_b[1] is None: + return lora_b + shard_size = self.output_dim + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = [ + lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx] + ] + return lora_b + def set_lora( self, index: int, @@ -479,13 +576,8 @@ def set_lora( self.reset_lora(index) if self.tp_size > 1: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_dim - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[0][:, - start_idx:end_idx], lora_b[1][:, - start_idx:end_idx] + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) if lora_a[0] is not None: self.lora_a_stacked[0][ @@ -502,10 +594,9 @@ def set_lora( index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( lora_b[1].T, non_blocking=True) - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -517,6 +608,7 @@ def apply_weights(self, x: torch.Tensor, return output @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -608,40 +700,44 @@ def create_lora_weights( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config self.tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() + self.tp_rank = get_tensor_model_parallel_rank() self.q_proj_shard_size = (self.base_layer.num_heads * self.base_layer.head_size) self.kv_proj_shard_size = (self.base_layer.num_kv_heads * self.base_layer.head_size) - self.q_shard_id = tp_rank - self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) # q, k, v self.lora_a_stacked = ( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + lora_a_output_size_per_partition, + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, 1, - lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + lora_a_output_size_per_partition, + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, 1, - lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + lora_a_output_size_per_partition, + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), ) self.lora_b_stacked = ( @@ -651,7 +747,7 @@ def create_lora_weights( self.q_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, @@ -659,7 +755,7 @@ def create_lora_weights( self.kv_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, @@ -667,7 +763,7 @@ def create_lora_weights( self.kv_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), ) @@ -675,7 +771,8 @@ def create_lora_weights( self.kv_proj_shard_size) self.packed_indices: Optional[torch.Tensor] = None self.standard_indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None + # lazily initialized. + self.indices_len: List[int] def reset_lora(self, index: int): self.lora_a_stacked[0][index] = 0 @@ -685,6 +782,30 @@ def reset_lora(self, index: int): self.lora_a_stacked[2][index] = 0 self.lora_b_stacked[2][index] = 0 + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + lora_b_q, lora_b_k, lora_b_v = None, None, None + if lora_b[0] is not None: + lora_b_q = lora_b[0][:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + if lora_b[1] is not None: + lora_b_k = lora_b[1][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + if lora_b[2] is not None: + lora_b_v = lora_b[2][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + lora_b = [lora_b_q, lora_b_k, lora_b_v] + return lora_b + def set_lora( self, index: int, @@ -695,40 +816,24 @@ def set_lora( self.reset_lora(index) if self.tp_size > 1: - if lora_b[0] is not None: - lora_b_q = lora_b[0][:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - self.lora_b_stacked[0][ - index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( - lora_b_q.T, non_blocking=True) - if lora_b[1] is not None: - lora_b_k = lora_b[1][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - self.lora_b_stacked[1][ - index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( - lora_b_k.T, non_blocking=True) - if lora_b[2] is not None: - lora_b_v = lora_b[2][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - self.lora_b_stacked[2][ - index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( - lora_b_v.T, non_blocking=True) - else: - if lora_b[0] is not None: - self.lora_b_stacked[0][ - index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( - lora_b[0].T, non_blocking=True) - if lora_b[1] is not None: - self.lora_b_stacked[1][ - index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( - lora_b[1].T, non_blocking=True) - if lora_b[2] is not None: - self.lora_b_stacked[2][ - index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_( - lora_b[2].T, non_blocking=True) + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + + if lora_b[0] is not None: + lora_b_q = lora_b[0] + self.lora_b_stacked[0][ + index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( + lora_b_q.T, non_blocking=True) + if lora_b[1] is not None: + lora_b_k = lora_b[1] + self.lora_b_stacked[1][ + index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( + lora_b_k.T, non_blocking=True) + if lora_b[2] is not None: + lora_b_v = lora_b[2] + self.lora_b_stacked[2][ + index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( + lora_b_v.T, non_blocking=True) if lora_a[0] is not None: self.lora_a_stacked[0][ @@ -743,10 +848,9 @@ def set_lora( index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( lora_a[2].T, non_blocking=True) - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -758,6 +862,7 @@ def apply_weights(self, x: torch.Tensor, return output @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -770,39 +875,61 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: super().__init__() self.base_layer = base_layer + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + self.device = _get_lora_device(self.base_layer) def create_lora_weights( self, max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config + self.tp_rank = get_tensor_model_parallel_rank() self.lora_a_stacked = torch.zeros( ( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) + tp_size = get_tensor_model_parallel_world_size() + lora_b_output_size_per_partition = ( + self.output_size if not lora_config.fully_sharded_loras else + divide(self.output_size, tp_size)) + self.lora_b_stacked = torch.zeros( ( max_loras, 1, - self.base_layer.weight.shape[0], + lora_b_output_size_per_partition, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) - self.indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None + # Lazily initialized + self.indices: torch.Tensor + self.indices_len: List[int] def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.input_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + def set_lora( self, index: int, @@ -811,12 +938,10 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], ): self.reset_lora(index) + if self.base_layer.tp_size > 1: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.base_layer.weight.shape[1] - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_a = lora_a[start_idx:end_idx, :] + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( @@ -831,14 +956,14 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = base_indices self.indices_len = indices_len - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x) + def apply(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) _apply_lora( x, self.lora_a_stacked, @@ -871,7 +996,7 @@ def forward(self, input_): input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. - output_parallel = self.apply_weights(input_parallel) + output_parallel = self.apply(input_parallel) if self.base_layer.reduce_results and self.base_layer.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: @@ -888,9 +1013,11 @@ def forward(self, input_): @property def weight(self): - return self.base_layer.weight + return self.base_layer.weight if hasattr( + self.base_layer, "weight") else self.base_layer.qweight @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -939,9 +1066,9 @@ def create_lora_weights( model_config: Optional[PretrainedConfig] = None, ) -> None: # Keep this in sync with csrc/punica/bgmv/bgmv_config.h - if 32000 < self.base_layer.vocab_size > 33024: + if 32000 < self.base_layer.vocab_size > 128512: raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 33024") + "32000 >= vocab_size <= 128512") self.lora_a_stacked = torch.zeros( ( max_loras, @@ -971,9 +1098,10 @@ def create_lora_weights( dtype=self.dtype, device=self.device, ) - self.indices = None - self.indices_padded = None - self.indices_len = None + # Lazily initialized. + self.indices: torch.Tensor + self.indices_len: List[int] + self.indices_padded: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -1005,6 +1133,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = sampler_indices @@ -1073,35 +1202,99 @@ def can_replace_layer(cls, source_layer: nn.Module, return False -_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { - cls - for cls in globals().values() if inspect.isclass(cls) - and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA -} - - -def from_layer(layer: nn.Module, - max_loras: int, - lora_config: LoRAConfig, - packed_modules_list: List, - model_config: Optional[PretrainedConfig] = None) -> nn.Module: - for lora_cls in _all_lora_classes: - if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list, - model_config): - ret = lora_cls(layer) - ret.create_lora_weights(max_loras, lora_config, model_config) - return ret - return layer - - -def from_layer_logits_processor( - layer: LogitsProcessor, - lm_head: ParallelLMHead, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, -) -> LogitsProcessorWithLoRA: - ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, - lm_head.weight.dtype, lm_head.weight.device) - ret.create_lora_weights(max_loras, lora_config, model_config) - return ret +class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): + """Implements RoPE-scaled embeddings with linear scaling for + multiple LoRA adapters with a specialized kernel. + + Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding + which can handle multi lora adapters in a specialied kernel. + """ + + def __init__(self, base_layer: RotaryEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + # Lazily initialized + self.long_lora_indices: torch.Tensor + self.indices_len: List[int] + + @property + def scaling_factors(self): + return self.base_layer.scaling_factors + + @property + def rotary_dim(self): + return self.base_layer.rotary_dim + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + scaling_factors = list( + lora_config.long_lora_scaling_factors + ) if lora_config.long_lora_scaling_factors else [] + base_scaling_factor = (self.base_layer.scaling_factor if isinstance( + self.base_layer, LinearScalingRotaryEmbedding) else 1.0) + scaling_factors = sorted( + list(set([base_scaling_factor] + scaling_factors))) + self.base_layer = LinearScalingRotaryEmbedding( + self.base_layer.head_size, + self.base_layer.rotary_dim, + self.base_layer.max_position_embeddings, + self.base_layer.base, + self.base_layer.is_neox_style, + scaling_factors, + self.base_layer.dtype, + ) + + def reset_lora(self, index: int): + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + ... + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, + indices_len: List[int], + ): + self.long_lora_indices = long_lora_indices + self.indices_len = indices_len + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.base_layer( + positions, + query, + key, + offsets=self.long_lora_indices[:self.indices_len[4]]) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self.base_layer.scaling_factor_to_offset + + @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + return type(source_layer) is LinearScalingRotaryEmbedding or type( + source_layer) is RotaryEmbedding + + def extra_repr(self) -> str: + return self.base_layer.extra_repr() diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 21c2196eb2739..d7794aa7cd35c 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -33,7 +33,7 @@ def __init__( def optimize(self) -> "LoRALayerWeights": """Optimize the LoRA by merging the scaling into lora_b.""" if self.scaling == 1: - return + return self self.lora_b *= self.scaling self.scaling = 1 return self @@ -97,9 +97,9 @@ def __init__( self, module_name: str, rank: int, - lora_alphas: List[int], - lora_a: List[torch.Tensor], - lora_b: List[torch.Tensor], + lora_alphas: List[Optional[int]], + lora_a: List[Optional[torch.Tensor]], + lora_b: List[Optional[torch.Tensor]], scaling: Optional[List[float]] = None, ) -> None: super().__init__( @@ -108,17 +108,20 @@ def __init__( lora_alpha=0, lora_a=lora_a, lora_b=lora_b, - scaling=scaling, + scaling=scaling, # type: ignore embeddings_tensor=None, ) self.lora_alphas = lora_alphas if scaling is None: - self.scaling = [ - lora_alpha / self.rank for lora_alpha in self.lora_alphas + self.scaling = [ # type: ignore + lora_alpha / self.rank # type: ignore # noqa + for lora_alpha in self.lora_alphas ] @classmethod - def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": + def pack( + cls, loras: List[Optional["LoRALayerWeights"]] + ) -> "PackedLoRALayerWeights": """Pack a list of LoRAs into a single LoRA. If LoRA is None, it signifies that the submodule does not have a LoRA. @@ -136,16 +139,19 @@ def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": [lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras], - scaling=[1 if lora is not None else None for lora in loras]) + scaling=[ + 1 if lora is not None else None # type: ignore + for lora in loras + ]) return obj def optimize(self) -> "PackedLoRALayerWeights": """Optimize the LoRA by merging the scaling into lora_b.""" for i in range(len(self.lora_b)): - if self.scaling[i] == 1 or self.lora_b[i] is None: + if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore continue - self.lora_b[i] *= self.scaling[i] - self.scaling[i] = 1 + self.lora_b[i] *= self.scaling[i] # type: ignore + self.scaling[i] = 1 # type: ignore return self @property diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 62f1502458008..3e82856866d85 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -3,7 +3,8 @@ import math import os import re -from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import safetensors.torch import torch @@ -11,10 +12,12 @@ from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer, - from_layer_logits_processor) +from vllm.lora.layers import (BaseLayerWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule +from vllm.lora.utils import (from_layer, from_layer_logits_processor, + parse_fine_tuned_lora_name, replace_submodule) from vllm.utils import LRUCache, is_pin_memory_available logger = init_logger(__name__) @@ -22,10 +25,27 @@ _GLOBAL_LORA_ID = 0 +@dataclass +class LongContextLoRAContext: + """Context for lora adapters that support long context.""" + # The scaling factors to support long context lora fine tuned models. + scaling_factors: List[float] + # dimension to apply rotary embedding. + rot_dim: int + # offsets to the sin_cos_cache for each lora_id loaded. + # This value is dynamically modified. + offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) + + def convert_mapping( - mapping: LoRAMapping, lora_index_to_id: List[Optional[int]], - max_loras: int, vocab_size: int, extra_vocab_size: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + mapping: LoRAMapping, + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional[LongContextLoRAContext] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: """Converts LoRAMapping to index tensors. Args: @@ -34,6 +54,7 @@ def convert_mapping( max_loras: Maximum number of LoRAs. vocab_size: Model vocab size. extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. Returns: A tuple of tensors: @@ -51,49 +72,78 @@ def convert_mapping( requests to embedding indices. First row is for embeddings added by the LoRAs, second row is for the LoRA.lora_a embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. indices_len: List of lengths of the above tensors. + Used to index into each tensor. It contains length for + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). If long_lora doesn't + exist, it only contains first 4 entries. """ - indices = list(mapping.index_mapping).copy() - embedding_indices = indices.copy() - lora_indices = indices.copy() - prompt_mapping = [ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device="cuda", + dtype=torch.long) + prompt_mapping: List[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping ] lora_idx = None - for i in range(len(indices)): + for i in range(len(index_mapping_indices)): # TODO index can be slow. optimize - lora_idx = (lora_index_to_id.index(indices[i]) - if indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if indices[i] > 0 else 0 - indices[i] = i + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 lora_indices[i] = lora_idx - - indices = torch.tensor([indices, lora_indices, embedding_indices], - dtype=torch.long, - device="cuda") - prompt_mapping = torch.tensor(prompt_mapping, - device="cuda", - dtype=torch.long) + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, lora_indices, embedding_indices + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + prompt_mapping_tensor = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) embeddings_indices = torch.stack([ indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size) ]) embeddings_indices[embeddings_indices == -1] = max_loras - 1 base_indices = indices[1] - sampler_indices = prompt_mapping + sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 sampler_indices_padded = ( torch.arange( 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (sampler_indices_padded * len(sampler_indices_padded))) - indices_len = (base_indices.shape[-1], sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], - embeddings_indices.shape[-1]) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. + indices_len = [ + base_indices.shape[-1], sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], embeddings_indices.shape[-1] + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) return (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, indices_len) + embeddings_indices, long_lora_indices, indices_len) def get_lora_id(): @@ -110,13 +160,35 @@ def __init__( lora_model_id: int, rank: int, loras: Dict[str, LoRALayerWeights], + scaling_factor: Optional[float] = None, ) -> None: + """ + Args: + lora_model_id: The integer id for the lora model. + rank: lora rank. + loras: module name -> weights for lora-replaced layers. + scaling_factor: Scaling factor to support long context lora model. + None if the lora is not tuned for long context support. + """ self.id = lora_model_id + # Scaling factor for long context lora model. None if it is not + # fine tuned for the long context. + self.scaling_factor = scaling_factor assert (lora_model_id > 0), f"a valid lora id should be greater than 0, got {self.id}" self.rank = rank self.loras: Dict[str, LoRALayerWeights] = loras + def clone(self, lora_model_id: int) -> "LoRAModel": + """Return a copy of the object with different ids. + + Will share the underlying tensors.""" + return self.__class__( + lora_model_id, + rank=self.rank, + loras=self.loras.copy(), + ) + @property def extra_vocab_size(self) -> int: return max(lora.extra_vocab_size @@ -138,6 +210,7 @@ def from_lora_tensors( dtype: Optional[torch.dtype] = None, embeddings: Optional[Dict[str, torch.Tensor]] = None, target_embedding_padding: Optional[int] = None, + scaling_factor: Optional[float] = None, embedding_modules: Optional[Dict[str, str]] = None, embedding_padding_modules: Optional[List[str]] = None, ) -> "LoRAModel": @@ -149,6 +222,7 @@ def from_lora_tensors( if module_name not in loras: lora_embeddings_tensor = None if embeddings: + assert embedding_modules is not None embeddings_module = next( (k for k in embedding_modules if k in module_name), None) @@ -171,6 +245,7 @@ def from_lora_tensors( else: loras[module_name].lora_b = tensor.to(device=device, dtype=dtype).t() + assert embedding_padding_modules is not None if any(name in module_name for name in embedding_padding_modules ) and target_embedding_padding is not None: @@ -185,13 +260,15 @@ def from_lora_tensors( for lora in loras.values(): lora.optimize() - return cls(lora_model_id, rank, loras) + return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor) @classmethod def from_local_checkpoint( cls, lora_dir: str, expected_lora_modules: List[str], + *, + max_position_embeddings: Optional[int] = None, lora_model_id: Optional[int] = None, device: str = "cuda", dtype: Optional[torch.dtype] = None, @@ -199,7 +276,23 @@ def from_local_checkpoint( embedding_modules: Optional[Dict[str, str]] = None, embedding_padding_modules: Optional[List[str]] = None, ) -> "LoRAModel": - """Create a LoRAModel from a local checkpoint.""" + """Create a LoRAModel from a local checkpoint. + + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + max_position_embeddings: Max position embedding length. Used to + scaling the largest context length. If None, the lora model's + context length is not scaled. + lora_model_id: Lora model id. If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ lora_config_path = os.path.join(lora_dir, "adapter_config.json") lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") @@ -212,10 +305,14 @@ def from_local_checkpoint( target_modules = config["target_modules"] unexpected_modules = [] for module in target_modules: - if module not in expected_lora_modules: + # Compatible with more modules, such as:layers.11.self_attn.k_proj + part_name = module.split(".")[-1] + if part_name not in expected_lora_modules: unexpected_modules.append(module) # loaded lora's target modules must be a subset of expected_lora_modules + if unexpected_modules: + print(unexpected_modules, "modules") raise ValueError( f"While loading {lora_dir}, expected" f" target modules in {expected_lora_modules}" @@ -237,6 +334,14 @@ def from_local_checkpoint( rank = config["r"] lora_alpha = config["lora_alpha"] + context_length = config.get("context_length", None) + scaling_factor = None + if context_length: + if max_position_embeddings is None: + max_position_embeddings = context_length + scaling_factor = float( + math.ceil(context_length / max_position_embeddings)) + return cls.from_lora_tensors( lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, @@ -247,6 +352,7 @@ def from_local_checkpoint( dtype=dtype, embeddings=embeddings, target_embedding_padding=target_embedding_padding, + scaling_factor=scaling_factor, embedding_modules=embedding_modules, embedding_padding_modules=embedding_padding_modules, ) @@ -280,6 +386,7 @@ def __init__( self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size + self.long_lora_context: Optional[LongContextLoRAContext] = None self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") @@ -293,16 +400,25 @@ def __init__( self.max_num_batched_tokens, dtype=torch.long, device="cuda") - self.offsets = [] + self.long_lora_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + # Scaling factor -> offset to the sin_cos_cache to it. + # Used for long context lora. + self.scaling_factor_to_offset: Dict[float, int] = {} # 4 is the number of indicies tensors defined above # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices - self.indices_len = [None] * 4 + self.indices_len: List[Optional[int]] = [None] * 4 self.model: nn.Module = model if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( self.model.supported_lora_modules) + if lora_config.long_lora_scaling_factors: + # We need to replace rotary emb layer to do batch computation + # for long lora. + self.supported_lora_modules.append("rotary_emb") self.packed_modules_mapping = copy.deepcopy( self.model.packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} @@ -310,7 +426,7 @@ def __init__( self._registered_loras: Dict[int, LoRAModel] = {} # Dict instead of a Set for compatibility with LRUCache. self._active_loras: Dict[int, None] = {} - self._last_mapping = None + self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() self.model.lora_manager = self @@ -340,8 +456,8 @@ def activate_lora( index, _ = first_free_slot self._active_loras[lora_id] = None lora_model = self._registered_loras[lora_id] - logger.debug( - f"Activating LoRA. int id: {lora_model.id}, slot index: {index}") + logger.debug("Activating LoRA. int id: %d, slot index: %d", + lora_model.id, index) self.lora_index_to_id[index] = lora_model.id for module_name, module in self.modules.items(): module_lora = lora_model.get_lora(module_name) @@ -368,12 +484,32 @@ def deactivate_lora(self, lora_id: int) -> bool: return True return False - def _add_lora(self, lora: LoRAModel) -> bool: + def _set_long_lora_context(self, lora: LoRAModel): + if self.long_lora_context is None: + return + + if lora.scaling_factor is None: + return + + if (lora.scaling_factor not in self.scaling_factor_to_offset): + raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}" + " has not been initialized.") + + offsets = self.scaling_factor_to_offset.get(lora.scaling_factor) + if offsets: + self.long_lora_context.offsets_by_lora_id[lora.id] = offsets + + def _add_lora(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) self._registered_loras[lora.id] = lora + self._set_long_lora_context(lora) def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager CPU cache.""" + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) if lora.id not in self._registered_loras: if len(self._registered_loras) >= self.capacity: raise RuntimeError("No free LoRA slots.") @@ -385,15 +521,18 @@ def remove_lora(self, lora_id: int) -> bool: """Remove a LoRAModel from the manager CPU cache.""" # TODO: should we check active lora? self.deactivate_lora(lora_id) + if self.long_lora_context: + self.long_lora_context.offsets_by_lora_id.pop(lora_id, None) return bool(self._registered_loras.pop(lora_id, None)) # TODO see if this can be vectorized def _set_lora_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, + embeddings_indices, long_lora_offsets_tensor, indices_len) = convert_mapping(mapping, self.lora_index_to_id, self.lora_slots + 1, self.vocab_size, - self.lora_config.lora_extra_vocab_size) + self.lora_config.lora_extra_vocab_size, + self.long_lora_context) self.base_indices[:base_indices.shape[0]].copy_(base_indices) self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( @@ -401,6 +540,11 @@ def _set_lora_mapping(self, mapping: LoRAMapping) -> None: self.embeddings_indices[:embeddings_indices. shape[0], :embeddings_indices.shape[1]].copy_( embeddings_indices) + if long_lora_offsets_tensor is not None: + self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self.long_lora_indices.zero_() # Maintain the reference self.indices_len[:] = indices_len @@ -416,14 +560,15 @@ def list_loras(self) -> Dict[int, LoRAModel]: def get_lora(self, lora_id: int) -> Optional[LoRAModel]: return self._registered_loras.get(lora_id, None) - def remove_all_loras(self) -> bool: + def remove_all_loras(self): """Remove all LoRAModels from the manager.""" self._registered_loras.clear() self.lora_index_to_id = [None] * self.lora_slots self._active_loras.clear() def _create_lora_modules(self): - for module_name, module in self.model.named_modules(): + for module_name, module in self.model.named_modules( + remove_duplicate=False): if not self._match_target_modules(module_name): continue parts = module_name.split(".")[-1] @@ -432,6 +577,13 @@ def _create_lora_modules(self): self.model, module_name, from_layer(module, self.lora_slots, self.lora_config, packed_moduled_lst, self.model.config)) + # LinearScalingRotaryEmbeddingWithLora is used to handle + # long context lora. Register relevant metadata. + if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): + self.long_lora_context = LongContextLoRAContext( + new_module.scaling_factors, new_module.rotary_dim) + self.scaling_factor_to_offset = \ + new_module.scaling_factor_to_offset # (yard1): TODO make this more robust if "lm_head" in module_name: logits_processor_module = self.model.get_submodule( @@ -446,7 +598,8 @@ def _create_lora_modules(self): self._register_packed_modules(module_name) new_module.set_mapping(self.base_indices, self.sampler_indices, self.sampler_indices_padded, - self.embeddings_indices, self.indices_len) + self.embeddings_indices, + self.long_lora_indices, self.indices_len) def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): assert isinstance(module, BaseLayerWithLoRA) @@ -456,15 +609,18 @@ def create_dummy_lora( self, lora_id: int, rank: int, + scaling_factor: Optional[float], embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: """Create zero-initialized LoRAModel for warmup.""" - model = LoRAModel(lora_id, rank, {}) + model = LoRAModel(lora_id, rank, {}, scaling_factor) for module_name, module in self.model.named_modules(): if not self._match_target_modules(module_name) or not isinstance( - module, BaseLayerWithLoRA): + module, BaseLayerWithLoRA) or isinstance( + module, LinearScalingRotaryEmbeddingWithLora): continue parts = module_name.split(".") if module_name not in self.packed_modules: + assert embedding_modules is not None if parts[-1] in embedding_modules: input_dim = (module.base_layer.org_vocab_size + self.lora_config.lora_extra_vocab_size if @@ -498,7 +654,7 @@ def create_dummy_lora( else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] - subloras = [] + subloras: List[Optional["LoRALayerWeights"]] = [] for i, r in enumerate(replacements): lora = LoRALayerWeights.create_dummy_lora_weights( module_name + "." + r, @@ -536,7 +692,7 @@ def _register_packed_modules(self, module_full_name: str) -> None: def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: for module_name, new_module_names in self.packed_modules.items(): - replacement_loras = [] + replacement_loras: List[Optional[LoRALayerWeights]] = [] has_replacement = False for r in new_module_names: lora = lora_model.get_lora(r) @@ -555,13 +711,13 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: class LoRALRUCache(LRUCache[LoRAModel]): - def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], - None]): + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], + bool]): super().__init__(capacity) self.deactivate_lora_fn = deactivate_lora_fn - def _on_remove(self, key: Hashable, value: LoRAModel): - logger.debug(f"Removing LoRA. int id: {key}") + def _on_remove(self, key: int, value: LoRAModel): + logger.debug("Removing LoRA. int id: %d", key) self.deactivate_lora_fn(key) return super()._on_remove(key, value) @@ -590,6 +746,10 @@ def list_loras(self) -> Dict[int, LoRAModel]: def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) if lora.id not in self._registered_loras: self._add_lora(lora) was_added = True diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index fc74269e55876..c87bed54726fc 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -49,6 +49,49 @@ def bgmv( punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) +def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, indicies: torch.LongTensor, + layer_idx: int, scale: float, y_offset: int, + y_slice_size: int): + """ + Same as `bgmv` but you can operate on slices of y. + Pass whole y, define y_offset and y_slice_size. + + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of + all of the transposed LoRA matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + y_offset: Offset to apply to the starting column of y. + y_slice_size: Size of the y column slice. + """ + try: + import vllm._punica_C as punica_kernels + except ImportError as e: + _raise_import_error(e) + punica_kernels.dispatch_bgmv_low_level( + y, + x, + w_t_all, + indicies, + layer_idx, + scale, + x.size(1), + y_slice_size, + y_offset, + ) + + def add_lora(y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor, diff --git a/vllm/lora/request.py b/vllm/lora/request.py index bbbf4880ab81b..662774ffe09ae 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional @dataclass @@ -18,6 +19,7 @@ class LoRARequest: lora_name: str lora_int_id: int lora_local_path: str + long_lora_max_len: Optional[int] = None def __post_init__(self): if self.lora_int_id < 1: diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 39e08f0412e4a..fcc7f24721939 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -1,11 +1,76 @@ -from typing import Tuple +from typing import List, Optional, Set, Tuple, Type from torch import nn +from transformers import PretrainedConfig +from vllm.config import LoRAConfig from vllm.logger import init_logger +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA) +# being imported for _all_lora_classes below +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead logger = init_logger(__name__) +_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { + VocabParallelEmbeddingWithLoRA, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLora, + RowParallelLinearWithLoRA, + LogitsProcessorWithLoRA, + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA, + LinearScalingRotaryEmbeddingWithLora, +} + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig] = None) -> nn.Module: + for lora_cls in _all_lora_classes: + # specifying kwargs so they can be easily accessed in decorator + if lora_cls.can_replace_layer(source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config): + ret = lora_cls(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_logits_processor( + layer: LogitsProcessor, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> LogitsProcessorWithLoRA: + ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, + lm_head.weight.dtype, lm_head.weight.device) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + def replace_submodule(model: nn.Module, module_name: str, new_module: nn.Module) -> nn.Module: diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index a0868defbd3ca..d67ce67172e30 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod, abstractproperty -from typing import Any, Dict, List, Optional, Set, Type +from contextlib import contextmanager +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union import torch @@ -16,15 +17,31 @@ class AbstractWorkerLoRAManager(ABC): """Abstract class for managing LoRA models on the worker side.""" - def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, - vocab_size: int, lora_config: LoRAConfig, - device: torch.device): + def __init__(self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + max_position_embeddings: Optional[int] = None): self.max_num_seqs = max_num_seqs self.max_num_batched_tokens = max_num_batched_tokens + self.max_position_embeddings = max_position_embeddings self.vocab_size = vocab_size self.device = device self.lora_config = lora_config + # If False, do not cache. If None, cache is empty. + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False + @abstractproperty def is_enabled(self) -> bool: ... @@ -37,7 +54,7 @@ def create_lora_manager( ... @abstractmethod - def set_active_loras(self, lora_requests: List[LoRARequest], + def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: ... @@ -54,7 +71,7 @@ def remove_lora(self, lora_id: int) -> bool: ... @abstractmethod - def remove_all_loras(self) -> bool: + def remove_all_loras(self): ... @abstractmethod @@ -80,13 +97,21 @@ def __init__( embedding_modules: Dict[str, str], embedding_padding_modules: List[str], lora_model_cls: Type[LoRAModel] = LoRAModel, + max_position_embeddings: Optional[int] = None, ): - self._lora_manager: Optional[LoRAModelManager] = None self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules - super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, - lora_config, device) + # Lazily initialized by create_lora_manager. + self._lora_manager: LoRAModelManager + super().__init__( + max_num_seqs, + max_num_batched_tokens, + vocab_size, + lora_config, + device, + max_position_embeddings=max_position_embeddings, + ) @property def is_enabled(self) -> bool: @@ -104,15 +129,15 @@ def create_lora_manager( lora_config=self.lora_config, lora_manager_cls=self._lora_manager_cls, ) - self._lora_manager: LoRAModelManager = lora_manager + self._lora_manager = lora_manager return lora_manager.model - def set_active_loras(self, lora_requests: List[LoRARequest], + def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: self._apply_loras(lora_requests) self._lora_manager.set_lora_mapping(lora_mapping) - def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: loras_that_exist = self.list_loras() loras_map = { lora_request.lora_int_id: lora_request @@ -149,6 +174,7 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: lora = self._lora_model_cls.from_local_checkpoint( lora_request.lora_local_path, expected_lora_modules, + max_position_embeddings=self.max_position_embeddings, lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, @@ -173,9 +199,15 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: if lora_request.lora_int_id in self.list_loras(): return False - return self._lora_manager.add_lora( - self._lora_manager.create_dummy_lora(lora_request.lora_int_id, - rank, self.embedding_modules)) + if isinstance(self._cached_dummy_lora, LoRAModel): + dummy_lora = self._cached_dummy_lora.clone( + lora_request.lora_int_id) + else: + dummy_lora = self._lora_manager.create_dummy_lora( + lora_request.lora_int_id, rank, 1, self.embedding_modules) + if self._cached_dummy_lora is None: + self._cached_dummy_lora = dummy_lora + return self._lora_manager.add_lora(dummy_lora) def add_lora(self, lora_request: LoRARequest) -> bool: if lora_request.lora_int_id in self.list_loras(): @@ -188,7 +220,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self._lora_manager.remove_lora(lora_id) - def remove_all_loras(self) -> bool: + def remove_all_loras(self): self._lora_manager.remove_all_loras() def list_loras(self) -> Set[int]: @@ -217,10 +249,10 @@ def create_lora_manager( lora_config=self.lora_config, max_num_batched_tokens=self.max_num_batched_tokens, ) - self._lora_manager: LRUCacheLoRAModelManager = lora_manager + self._lora_manager = lora_manager return lora_manager.model - def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request @@ -237,12 +269,14 @@ def add_lora(self, lora_request: LoRARequest) -> bool: if lora_request.lora_int_id not in self.list_loras(): # Remove before we load the new lora to save memory if len(self._lora_manager) + 1 > self._lora_manager.capacity: + assert isinstance(self._lora_manager, LRUCacheLoRAModelManager) self._lora_manager.remove_oldest_lora() lora = self._load_lora(lora_request) loaded = self._lora_manager.add_lora(lora) else: # If the lora is already loaded, just touch it to # update its position in the caches - loaded = self._lora_manager.get_lora(lora_request.lora_int_id) + loaded = self._lora_manager.get_lora( + lora_request.lora_int_id) is not None self._lora_manager.activate_lora(lora_request.lora_int_id) return loaded diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py new file mode 100644 index 0000000000000..0558d6c95d97b --- /dev/null +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -0,0 +1,25 @@ +from typing import Optional, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( + get_lm_format_enforcer_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) +from vllm.sampling_params import LogitsProcessor + + +async def get_guided_decoding_logits_processor( + guided_decoding_backend: str, request: Union[CompletionRequest, + ChatCompletionRequest], + tokenizer) -> Optional[LogitsProcessor]: + if guided_decoding_backend == 'outlines': + return await get_outlines_guided_decoding_logits_processor( + request, tokenizer) + if guided_decoding_backend == 'lm-format-enforcer': + return await get_lm_format_enforcer_guided_decoding_logits_processor( + request, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_decoding_backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer'") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py new file mode 100644 index 0000000000000..d0a5ca5592f9d --- /dev/null +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -0,0 +1,70 @@ +from functools import lru_cache +from json import loads as json_loads +from typing import Optional, Union + +from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser, + RegexParser, StringParser, + TokenEnforcerTokenizerData, UnionParser) +from lmformatenforcer.integrations.vllm import ( + build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) +from pydantic import BaseModel +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) +from vllm.sampling_params import LogitsProcessor + + +async def get_lm_format_enforcer_guided_decoding_logits_processor( + request: Union[CompletionRequest, ChatCompletionRequest], + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if request.guided_json: + schema = _normalize_json_schema_object(request.guided_json) + character_level_parser = JsonSchemaParser(schema) + elif request.guided_choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in request.guided_choice]) + elif request.guided_regex: + character_level_parser = RegexParser(request.guided_regex) + elif request.guided_grammar: + # CFG grammar not supported by LMFE, revert to outlines + return await get_outlines_guided_decoding_logits_processor( + request, tokenizer) + elif (request.response_format is not None + and request.response_format.type == "json_object"): + character_level_parser = JsonSchemaParser( + None) # None means any json object + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + +def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: + if isinstance(schema, str): + return json_loads(schema) + if isinstance(schema, dict): + return schema + if isinstance(schema, BaseModel): + return schema.model_json_schema() + raise AssertionError(f"Unsupported schema type {schema}") + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData: + return build_vllm_token_enforcer_tokenizer_data(tokenizer) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py similarity index 85% rename from vllm/model_executor/guided_decoding.py rename to vllm/model_executor/guided_decoding/outlines_decoding.py index e56f74c7794fb..8403604286903 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -12,9 +12,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) -from vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor, - JSONLogitsProcessor, - RegexLogitsProcessor) +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) class GuidedDecodingMode(Enum): @@ -54,9 +53,9 @@ class GuidedDecodingMode(Enum): global_thread_pool = None # used for generating logits processor fsm -async def get_guided_decoding_logits_processor( +async def get_outlines_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], - tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: + tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]: """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. @@ -75,7 +74,8 @@ async def get_guided_decoding_logits_processor( result = await loop.run_in_executor(global_thread_pool, _get_cached_logits_processor, guide, - tokenizer, mode) + tokenizer, mode, + request.guided_whitespace_pattern) logits_processor = copy(result) # reset logits processor's internal state @@ -85,13 +85,13 @@ async def get_guided_decoding_logits_processor( def _get_guide_and_mode( request: Union[CompletionRequest, ChatCompletionRequest] -) -> Tuple[str, GuidedDecodingMode]: +) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: if request.guided_json: json = request.guided_json if isinstance(json, dict): # turn dict into hashable string - json = json_dumps(json, sort_keys=True) + json = json_dumps(json) elif isinstance(json, BaseModel): # use pydantic signature so that different model classes # with the same fields will get hashed the same @@ -118,9 +118,10 @@ def _get_guide_and_mode( @lru_cache(maxsize=32) def _get_cached_logits_processor(guide: str, tokenizer: PreTrainedTokenizerBase, - mode: GuidedDecodingMode): + mode: GuidedDecodingMode, + whitespace_pattern: Union[str, None]): if mode == GuidedDecodingMode.JSON: - return JSONLogitsProcessor(guide, tokenizer) + return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: return RegexLogitsProcessor(guide, tokenizer) elif mode == GuidedDecodingMode.GRAMMAR: diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py similarity index 64% rename from vllm/model_executor/guided_logits_processors.py rename to vllm/model_executor/guided_decoding/outlines_logits_processors.py index 035fe00037328..a131c6a1b92b4 100644 --- a/vllm/model_executor/guided_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -13,13 +13,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import json import math from collections import defaultdict -from typing import Callable, DefaultDict, Dict, List, Optional, Union +from functools import lru_cache +from typing import Callable, DefaultDict, Dict, List, Union import torch -from outlines.fsm.fsm import CFGFSM, RegexFSM +from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel from transformers import PreTrainedTokenizerBase @@ -27,49 +29,9 @@ class BaseLogitsProcessor: - def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase): - """Adapt vLLM's tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of - `transformers`. The decoder of outlines, returns a list whereas - the decode of vLLM returns an str. To sync the vLLM decoder with - outlines internal api, the decoder should be adapted. In addition - we need to handle the missing spaces to Llama's tokenizer to be - able to compile FSMs for this model. - - """ - if getattr(tokenizer, "_outlines_adapted", False): - return tokenizer - - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) - - def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = tokenizer.convert_tokens_to_string([token]) - - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - def change_decoder( - decoder: Callable[[List[int]], str] - ) -> Callable[[List[int]], List[str]]: - """Sync vLLM's decoder with the outlines by returning list.""" - - def new_decoder(inp_tokens: List[int]) -> List[str]: - return [decoder(inp_tokens)] - - return new_decoder - - tokenizer.convert_token_to_string = convert_token_to_string - tokenizer.decode = change_decoder(tokenizer.decode) - setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 - - return tokenizer + def __init__(self): + # Child class should use initialize in their init. + self.fsm: FSM def init_state(self): """Initialize the FSM states.""" @@ -78,7 +40,6 @@ def init_state(self): def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" - seq_id = hash(tuple(input_ids)) if len(input_ids) == 0: @@ -96,7 +57,6 @@ def __call__(self, input_ids: List[int], device=scores.device) mask[allowed_tokens] = 0 scores.add_(mask) - return scores @@ -113,17 +73,16 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - tokenizer = self.adapt_tokenizer(tokenizer) + tokenizer = _adapt_tokenizer(tokenizer) fsm = RegexFSM(regex_string, tokenizer) self.fsm = fsm class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, - schema: Union[str, Dict, BaseModel], + def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer: PreTrainedTokenizerBase, - whitespace_pattern: Optional[str] = None): + whitespace_pattern: Union[str, None]): """Compile the FSM that drives the JSON-guided generation. Parameters @@ -167,6 +126,59 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - tokenizer = self.adapt_tokenizer(tokenizer) + tokenizer = _adapt_tokenizer(tokenizer) fsm = CFGFSM(cfg, tokenizer) self.fsm = fsm + + def init_state(self): + """Initialize state with a CFGFSM copy.""" + super().init_state() + self.fsm = self.fsm.copy() + + +@lru_cache +def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. The decoder of outlines, returns a list whereas + the decode of vLLM returns an str. To sync the vLLM decoder with + outlines internal api, the decoder should be adapted. In addition + we need to handle the missing spaces to Llama's tokenizer to be + able to compile FSMs for this model. + + """ + if getattr(tokenizer, "_outlines_adapted", False): + return tokenizer + + tokenizer = copy.deepcopy(tokenizer) + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def change_decoder( + decoder: Callable[[List[int]], + str]) -> Callable[[List[int]], List[str]]: + """Sync vLLM's decoder with the outlines by returning list.""" + + def new_decoder(inp_tokens: List[int]) -> List[str]: + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 + + return tokenizer diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index f569a5a49cbdf..d101aa323b0e1 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -6,11 +6,10 @@ import torch.nn as nn import torch.nn.functional as F -from vllm._C import ops +from vllm import _custom_ops as ops +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import divide from vllm.model_executor.utils import set_weight_attrs @@ -68,6 +67,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ops.gelu_tanh_and_mul(out, x) return out + def extra_repr(self) -> str: + return f'approximate={repr(self.approximate)}' + class NewGELU(nn.Module): diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 496d69c89c62b..2926c7d1c8a76 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,7 +1,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_moe, get_config_file_name) + fused_experts, fused_moe, fused_topk, get_config_file_name) __all__ = [ "fused_moe", + "fused_topk", + "fused_experts", "get_config_file_name", ] diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..3f3ccdafa88f3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,138 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000000000..0bb423b28f5ab --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..0c495e7e290c6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..26bcbf26970c7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..5b78c30f08b68 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000000000..dbc624731f5cb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..60a65724d68b9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..32c0c9da471cb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..75f8b0017b9c6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..34b916e574f88 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a48b4385ba917..7a3c6ec773358 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -9,7 +9,7 @@ import triton.language as tl import vllm._moe_C as moe_kernels -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.logger import init_logger logger = init_logger(__name__) @@ -21,6 +21,8 @@ def fused_moe_kernel( a_ptr, b_ptr, c_ptr, + a_scale_ptr, + b_scale_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, @@ -49,6 +51,7 @@ def fused_moe_kernel( MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, + use_fp8: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -111,6 +114,10 @@ def fused_moe_kernel( b_ptrs = (b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + if use_fp8: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block @@ -131,7 +138,10 @@ def fused_moe_kernel( mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. - accumulator += tl.dot(a, b) + if use_fp8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -142,7 +152,10 @@ def fused_moe_kernel( other=0) accumulator = accumulator * moe_weight[:, None] - accumulator = accumulator.to(compute_type) + if use_fp8: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -192,17 +205,15 @@ def moe_align_block_size( - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. """ - sorted_ids = torch.empty( - (topk_ids.numel() + num_experts * (block_size - 1), ), - dtype=torch.int32, - device=topk_ids.device, - ) - expert_ids = torch.empty( - (topk_ids.numel() + num_experts, ), - dtype=torch.int32, - device=topk_ids.device, - ) + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) @@ -217,22 +228,26 @@ def moe_align_block_size( return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel( - A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: Dict[str, Any], -) -> None: +def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8: bool) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + if not use_fp8: + assert A_scale is None + assert B_scale is None + else: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + assert B_scale is not None + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) @@ -240,6 +255,8 @@ def invoke_fused_moe_kernel( A, B, C, + A_scale, + B_scale, topk_weights, sorted_token_ids, expert_ids, @@ -257,18 +274,21 @@ def invoke_fused_moe_kernel( C.stride(2), MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, - compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, + compute_type=compute_type, + use_fp8=use_fp8, **config, ) -def get_config_file_name(E: int, N: int) -> str: +def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: device_name = torch.cuda.get_device_name().replace(" ", "_") - return f"E={E},N={N},device_name={device_name}.json" + dtype_selector = "" if not dtype else f",dtype={dtype}" + return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" @functools.lru_cache -def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: +def get_moe_configs(E: int, N: int, + dtype: Optional[str]) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -280,14 +300,14 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N) + json_file_name = get_config_file_name(E, N, dtype) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) if os.path.exists(config_file_path): with open(config_file_path) as f: - logger.info( - f"Using configuration from {config_file_path} for MoE layer.") + logger.info("Using configuration from %s for MoE layer.", + config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} @@ -296,49 +316,16 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: return None -def fused_moe( +def fused_topk( hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. +): + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert (hidden_states.shape[0] == gating_output.shape[0] - ), "Number of tokens mismatch" - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] M, _ = hidden_states.shape - E, N, _ = w1.shape topk_weights = torch.empty(M, topk, @@ -362,12 +349,40 @@ def fused_moe( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + M, _ = hidden_states.shape + E, N, _ = w1.shape if override_config: config = override_config else: # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2]) + configs = get_moe_configs(E, w2.shape[2], + "float8" if use_fp8 else None) if configs: # If an optimal configuration map has been found, look up the @@ -407,37 +422,43 @@ def fused_moe( ) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E) - - invoke_fused_moe_kernel( - hidden_states, - w1, - intermediate_cache1, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - ) + topk_ids, config['BLOCK_SIZE_M'], E) + compute_type = (tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16) + + invoke_fused_moe_kernel(hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - ) + invoke_fused_moe_kernel(intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8) if inplace: return torch.sum( @@ -447,3 +468,63 @@ def fused_moe( ) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + override_config=override_config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index cb3cee2bad5ad..8de0794158986 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm._C import ops +from vllm import _custom_ops as ops class RMSNorm(nn.Module): @@ -64,3 +64,8 @@ def forward( self.variance_epsilon, ) return out + + def extra_repr(self) -> str: + s = f"hidden_size={self.weight.data.size(0)}" + s += f", eps={self.variance_epsilon}" + return s diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index feab230f033b6..29b5fe77ae705 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,19 +1,22 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from abc import abstractmethod +from typing import List, Optional import torch import torch.nn.functional as F from torch.nn.parameter import Parameter +from vllm import _custom_C +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.tuned_gemm import tgemm -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import ( - divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import is_hip logger = init_logger(__name__) @@ -26,23 +29,37 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size -class LinearMethodBase(ABC): +class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - """Create weights for a linear layer.""" + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ raise NotImplementedError @abstractmethod - def apply_weights(self, - weights: Dict[str, torch.Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - """Apply the weights to the input tensor.""" + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -57,22 +74,48 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - weight = Parameter(torch.empty(output_size_per_partition, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - return {"weight": weight} - - def apply_weights(self, - weights: Dict[str, torch.Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - weight = weights["weight"] + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + if is_hip() and x.dtype == torch.float16 and x.view(-1, x.size(-1)).shape[0] == 1: + batched = False + if x.dim() == 3: + inp = x.view(-1, x.size(-1)) + batched = True + else: + inp = x + m, k = weight.shape[0], inp.shape[1] + out = torch.empty(inp.shape[0], + weight.shape[0], + dtype=inp.dtype, + device='cuda') + if (k == 8192 and + (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): + _custom_C.LLMM1(weight, inp, out, 8) + elif k <= 8192 and k % 8 == 0 and m % 4 == 0: + _custom_C.LLMM1(weight, inp, out, 4) + else: + out = F.linear(inp, weight) + if batched: + out = out.view(x.shape[0], x.shape[1], weight.shape[0]) + if bias is not None: + out = out + bias + return out if self.separate_bias_add: if bias is not None: return F.linear(x, weight) + bias @@ -82,8 +125,8 @@ def apply_weights(self, return tgemm.mm(x, weight) -class ReplicatedLinear(torch.nn.Module): - """Replicated linear layer. +class LinearBase(torch.nn.Module): + """Base linear layer. Args: input_size: input dimension of the linear layer. @@ -91,17 +134,16 @@ class ReplicatedLinear(torch.nn.Module): bias: If true, add bias. skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ def __init__( self, input_size: int, output_size: int, - bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -112,15 +154,44 @@ def __init__( if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - if linear_method is None: - linear_method = UnquantizedLinearMethod() - self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config) + + # All the linear layer supports quant method. + assert self.quant_method is not None + self.quant_method.create_weights(self, self.input_size, + [self.output_size], self.input_size, + self.output_size, self.params_dtype) + if bias: self.bias = Parameter( torch.empty(self.output_size, dtype=self.params_dtype)) @@ -130,12 +201,19 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None - output = self.linear_method.apply_weights(self.linear_weights, x, bias) + assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + -class ColumnParallelLinear(torch.nn.Module): +class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. The linear layer is defined as Y = XA + b. A is parallelized along @@ -152,42 +230,47 @@ class ColumnParallelLinear(torch.nn.Module): bias can be fused with other element-wise operations. we skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None): + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config) - # Keep input parameters - self.input_size = input_size - self.output_size = output_size self.gather_output = gather_output + # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, tp_size) - self.skip_bias_add = skip_bias_add - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - if linear_method is None: - linear_method = UnquantizedLinearMethod() - self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + assert self.quant_method is not None + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, tp_size) + for output_size in self.output_sizes + ] + + if output_sizes is None: + output_sizes = [output_size] + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -200,6 +283,10 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) param_data = param.data @@ -208,6 +295,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer(param_data, + loaded_weight, + shard_id=0) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -215,8 +308,8 @@ def forward(self, input_): bias = self.bias if not self.skip_bias_add else None # Matrix multiply. - output_parallel = self.linear_method.apply_weights( - self.linear_weights, input_, bias) + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -225,6 +318,14 @@ def forward(self, input_): output_bias = self.bias if self.skip_bias_add else None return output, output_bias + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", gather_output={self.gather_output}" + return s + class MergedColumnParallelLinear(ColumnParallelLinear): """Packed linear layers with column parallelism. @@ -244,92 +345,61 @@ class MergedColumnParallelLinear(ColumnParallelLinear): bias can be fused with other element-wise operations. we skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, - ): + def __init__(self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, linear_method) - self._wsfs = [0, 0] - self._asfs = [0, 0] - self._osfs = [0, 0] - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs( - weight, { - "asf_loader": self.asf_loader, - "wsf_loader": self.wsf_loader, - "osf_loader": self.osf_loader, - "rescale": self.rescale, - }) - - def rescale(self): - if self._wsfs[0] < self._wsfs[1]: - factor = self._wsfs[0] / self._wsfs[1] - else: - factor = self._wsfs[1] / self._wsfs[0] - - weight_param = self.linear_weights["weight"] - param_data = weight_param.data - loaded_shard_id = 0 if self._wsfs[0] < self._wsfs[1] else 1 - tp_size = get_tensor_model_parallel_world_size() - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size - output_dim = getattr(weight_param, "output_dim", None) - param_data = param_data.narrow(output_dim, shard_offset, shard_size) - param_data16 = torch.empty_like(param_data, dtype=torch.float16) - ops.convert_fp8( - param_data, param_data16, - torch.tensor(factor, dtype=torch.float32, device='cuda')) - ops.convert_fp8(param_data16, param_data, - torch.tensor(1, dtype=torch.float32, device='cuda')) - #param_data *= factor[0] - pass - - def wsf_loader(self, param: Parameter, sf: torch.Tensor, - loaded_shard_id: int): - self._wsfs[loaded_shard_id] = sf - param.data.copy_(max(self._wsfs)) - - def asf_loader(self, param: Parameter, sf: torch.Tensor, - loaded_shard_id: int): - self._asfs[loaded_shard_id] = sf - param.data.copy_(max(self._asfs)) - - def osf_loader(self, param: Parameter, sf: torch.Tensor, - loaded_shard_id: int): - self._osfs[loaded_shard_id] = 1 / sf.item() - param.data.copy_(max(self._osfs)) + super().__init__(input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): - if (param.data.dtype == torch.float8_e4m3fnuz - and loaded_weight.dtype != torch.int8): - # Quantized weights for this layer have already been loaded - return - if (param.data.dtype != loaded_weight.dtype - and loaded_weight.dtype == torch.int8 - and param.data.dtype != torch.float8_e4m3fnuz): - param.data = torch.empty_like(param.data, - dtype=torch.float8_e4m3fnuz) param_data = param.data output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: @@ -343,14 +413,13 @@ def weight_loader(self, current_shard_offset += output_size packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -365,15 +434,14 @@ def weight_loader(self, if output_dim is not None: shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -382,6 +450,24 @@ def weight_loader(self, start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -389,11 +475,15 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + + if fp8_scales_shard_indexer is None: + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape - if param_data.dtype == torch.float8_e4m3fnuz: - loaded_weight[loaded_weight == -128] = 0 - assert param_data.is_contiguous() and loaded_weight.is_contiguous() - loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz) param_data.copy_(loaded_weight) @@ -418,20 +508,18 @@ class QKVParallelLinear(ColumnParallelLinear): bias can be fused with other element-wise operations. we skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ - def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, - ): + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -451,26 +539,51 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, linear_method) + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__(input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): - if param.data.dtype == torch.float8_e4m3fnuz and loaded_weight.dtype == torch.float16: - # Quantized weights for this layer have already been loaded - return param_data = param.data output_dim = getattr(param, "output_dim", None) - - if param.data.dtype != loaded_weight.dtype and loaded_weight.dtype == torch.int8: - param.data = torch.empty_like(param.data, - dtype=torch.float8_e4m3fnuz) - loaded_weight[loaded_weight == -128] = 0 - loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz) - self.weight_loader(param, loaded_weight) - return + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) if loaded_shard_id is None: # Loaded weight is already packed. @@ -488,14 +601,14 @@ def weight_loader(self, ] packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -506,6 +619,8 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. if output_dim is not None: if loaded_shard_id == "q": shard_offset = 0 @@ -517,6 +632,7 @@ def weight_loader(self, shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) @@ -524,8 +640,7 @@ def weight_loader(self, shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -538,6 +653,23 @@ def weight_loader(self, start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, + shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -545,11 +677,18 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " "for all partitions.") + + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) -class RowParallelLinear(torch.nn.Module): +class RowParallelLinear(LinearBase): """Linear layer with row parallelism. The linear layer is defined as Y = XA + b. A is parallelized along @@ -572,45 +711,36 @@ class RowParallelLinear(torch.nn.Module): bias can be fused with other element-wise operations. We skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - # Keep input parameters - self.input_size = input_size - self.output_size = output_size + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None): + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config) + self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) - self.skip_bias_add = skip_bias_add - if linear_method is None: - linear_method = UnquantizedLinearMethod() - self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) - + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") @@ -626,6 +756,10 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) param_data = param.data @@ -634,6 +768,16 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer(param_data, + loaded_weight, + shard_id=0) + + if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -648,8 +792,8 @@ def forward(self, input_): input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. - output_parallel = self.linear_method.apply_weights( - self.linear_weights, input_parallel) + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: @@ -662,3 +806,11 @@ def forward(self, input_): output = output_ output_bias = self.bias return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 4d96b91422a7c..d81c00437d75f 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -1,12 +1,12 @@ """A layer that compute logits from hidden_stats.""" +import inspect from typing import Optional import torch import torch.nn as nn +from vllm.distributed import tensor_model_parallel_gather from vllm.model_executor.layers.tuned_gemm import tgemm -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -72,6 +72,12 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, logits = logits[:, :self.org_vocab_size] return logits + def extra_repr(self) -> str: + s = f"vocab_size={self.vocab_size}" + s += f", forg_vocab_size={self.org_vocab_size}" + s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" + return s + def _prune_hidden_states( hidden_states: torch.Tensor, @@ -85,21 +91,37 @@ def _apply_logits_processors( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - logits_row_idx = 0 found_logits_processors = False - for seq_ids, sampling_params in sampling_metadata.seq_groups: + logits_processed = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params logits_processors = sampling_params.logits_processors if logits_processors: found_logits_processors = True - for seq_id in seq_ids: + + for seq_id, logits_row_idx in zip(seq_ids, + seq_group.sample_indices): logits_row = logits[logits_row_idx] - token_ids = sampling_metadata.seq_data[seq_id].output_token_ids + past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids + prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids + for logits_processor in logits_processors: - logits_row = logits_processor(token_ids, logits_row) + parameters = inspect.signature(logits_processor).parameters + if len(parameters) == 3: + logits_row = logits_processor(prompt_tokens_ids, + past_tokens_ids, + logits_row) + else: + logits_row = logits_processor(past_tokens_ids, + logits_row) + logits[logits_row_idx] = logits_row - logits_row_idx += 1 - else: - logits_row_idx += len(seq_ids) + + logits_processed += len(seq_group.sample_indices) + len( + seq_group.prompt_logprob_indices) + if found_logits_processors: - assert logits_row_idx == logits.shape[0] + # verifies that no rows in logits were missed unexpectedly + assert logits_processed == logits.shape[0] return logits diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py index a19e9461f41f7..d08ae6064aa2a 100644 --- a/vllm/model_executor/layers/ops/sample.py +++ b/vllm/model_executor/layers/ops/sample.py @@ -29,8 +29,8 @@ def _multi_split_sample( sampled_tokens_size: Tuple[int, int], sampled_logprobs_size: Tuple[int, int], sample_indices: torch.Tensor, + logprobs: torch.Tensor, *, - logprobs: Optional[torch.Tensor] = None, modify_greedy_probs: bool = False, save_logprobs: bool = False, ): @@ -167,6 +167,7 @@ def sample( sampled_logprobs_size = (0, 0) logprobs = probs + assert logprobs is not None if _save_modified_probs: sampled_modified_probs_size = sampled_tokens_size else: diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py new file mode 100644 index 0000000000000..445b30b8c6e9b --- /dev/null +++ b/vllm/model_executor/layers/pooler.py @@ -0,0 +1,56 @@ +from enum import IntEnum + +import torch +import torch.nn as nn + +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput + + +class PoolingType(IntEnum): + """Enumeration for different types of pooling methods.""" + LAST = 0 + + +class Pooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use (LAST, AVERAGE, MAX). + normalize: Whether to normalize the pooled data. + """ + + def __init__(self, pooling_type: PoolingType, normalize: bool): + super().__init__() + self.pooling_type = pooling_type + self.normalize = normalize + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Pools specific information from hidden states based on metadata.""" + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + if self.pooling_type == PoolingType.LAST: + last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 + pooled_data = hidden_states[last_token_flat_indices] + else: + raise ValueError(f"Invalid pooling type: {self.pooling_type}") + + if self.normalize: + pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data + ] + + return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index de4c3f14e2d4a..b963576aa4471 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,29 +1,48 @@ -from typing import Type +from typing import Dict, Type +from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.fp8_rocm import Fp8RocmConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQMarlin24Config) from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig -from vllm.model_executor.layers.quantization.fp8_rocm import Fp8RocmConfig +from vllm.utils import is_hip -_QUANTIZATION_CONFIG_REGISTRY = { +QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { + "aqlm": AQLMConfig, "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, + "fp8": Fp8Config if not is_hip() else Fp8RocmConfig, + # The order of gptq methods is important for config.py iteration over + # override_quantization_method(..) + "marlin": MarlinConfig, + "gptq_marlin_24": GPTQMarlin24Config, + "gptq_marlin": GPTQMarlinConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, - "marlin": MarlinConfig, - "fp8": Fp8RocmConfig + "sparseml": CompressedTensorsConfig, } def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: - if quantization not in _QUANTIZATION_CONFIG_REGISTRY: + if quantization not in QUANTIZATION_METHODS: raise ValueError(f"Invalid quantization method: {quantization}") - return _QUANTIZATION_CONFIG_REGISTRY[quantization] + return QUANTIZATION_METHODS[quantization] __all__ = [ "QuantizationConfig", "get_quantization_config", + "QUANTIZATION_METHODS", ] diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py new file mode 100644 index 0000000000000..730595c3d36d1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -0,0 +1,376 @@ +# Supports AQLM compression, see https://github.com/Vahe1994/AQLM +# and https://arxiv.org/pdf/2401.06118.pdf + +import math +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +def get_int_dtype(nbits: int) -> torch.dtype: + if nbits <= 8: + return torch.int8 + if nbits <= 16: + return torch.int16 + if nbits <= 32: + return torch.int32 + if nbits <= 64: + return torch.int64 + raise ValueError(f"No dtype available for {nbits}-bit codebooks") + + +@torch.inference_mode() +def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor: + return data.to(torch.int64) % (2**nbits) + + +def dequantize_weight(codes: torch.Tensor, + codebooks: torch.Tensor, + scales: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Decode float weights from quantization codes. Differentiable. + :param codes: tensor of integer quantization codes, shape + [*dims, num_out_groups, num_in_groups, num_codebooks] + :param codebooks: tensor of vectors for each quantization code, + [num_codebooks, codebook_size, out_group_size, in_group_size] + :param scales: weight will be multiplied by this factor, must be + broadcastble with + [*dims, out_groups, num_in_groups, out_group_size, in_group_size] + :return: reconstructed weight tensor of shape + [*dims, num_in_groups*group_size] + """ + num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] + num_codebooks, codebook_size, out_group_size, in_group_size = \ + codebooks.shape + out_features = num_out_groups * out_group_size + in_features = num_in_groups * in_group_size + codebook_offsets = torch.arange( + 0, num_codebooks * codebook_size, codebook_size, + device=codes.device) # shape: [num_codebooks] + reconstructed_weight_flat = F.embedding_bag( + codes.flatten(0, -2) + codebook_offsets, + codebooks.flatten(0, 1).flatten(-2, -1), + mode="sum" + ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size + # * in_group_size] + + reconstructed_weight_groupwise = reconstructed_weight_flat.view( + list(codes.shape[:-3]) + + [num_out_groups, num_in_groups, out_group_size, in_group_size]) + if scales is not None: + reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul( + scales) + return reconstructed_weight_groupwise.swapaxes( + -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) + + +def dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + bias: Optional[torch.Tensor], +) -> torch.Tensor: + dequantized_weight = dequantize_weight( + unpack_int_data(codes, codebooks.shape[1].bit_length() - 1), + codebooks, + scales, + ) + return F.linear(input, dequantized_weight, bias) + + +# Generic dequantization, slow but flexible. +def generic_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + output_shape = input.shape[:-1] + (scales.shape[0], ) + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + num_outputs = len(output_partition_sizes) + + # break the inputs and codebooks apart then combine the outputs. + # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big + # multiply at the end. + num_codebooks = codebooks.shape[0] // num_outputs + assert (scales.shape[0] == codes.shape[0]) + assert (sum(output_partition_sizes) == scales.shape[0]) + output_offset = 0 + codebooks_offset = 0 + for output_size in output_partition_sizes: + shard_output = dequantize_gemm( + input, codes.narrow(0, output_offset, output_size), + codebooks.narrow(0, codebooks_offset, num_codebooks), + scales.narrow(0, output_offset, output_size), None + if bias is None else bias.narrow(0, output_offset, output_size)) + + output_slice = output.narrow(-1, output_offset, output_size) + assert (output_slice.shape == shard_output.shape) + output_slice.copy_(shard_output) + output_offset += output_size + codebooks_offset += num_codebooks + return output + + +# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8 +# at 6 and 9 times faster than the generic version above, respectively. +def optimized_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + if bias is None: + # scaling the output is fastest, so we do that when possible. + output = F.linear(input, weights, bias) + orig_shape = output.shape + flattened_output = output.view(-1, output.size(-1)) + f_scales = scales.view(-1, scales.shape[0]) + b_scales = f_scales.expand(flattened_output.shape[0], -1) + flattened_output *= b_scales + return output.view(orig_shape) + else: + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +class AQLMConfig(QuantizationConfig): + """Config class for AQLM. + + Reference: https://github.com/Vahe1994/AQLM + """ + + def __init__( + self, + in_group_size: int, + nbits_per_codebook: int, + num_codebooks: int, + out_group_size: int, + ) -> None: + self.in_group_size = in_group_size + self.nbits_per_codebook = nbits_per_codebook + self.num_codebooks = num_codebooks + self.out_group_size = out_group_size + + # out_group_size > 1 is untested, and probably won't work as-is. + assert (self.out_group_size == 1) + self.pack_factor = (self.in_group_size * self.out_group_size) + + def __repr__(self) -> str: + return (f"AQLMConfig(in_group_size={self.in_group_size}, " + f"nbits_per_codebook={self.nbits_per_codebook}, " + f"num_codebooks={self.num_codebooks}, " + f"out_group_size={self.out_group_size})") + + @classmethod + def get_name(cls) -> str: + return "aqlm" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] # no extra configs. + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": + in_group_size = cls.get_from_keys(config, ["in_group_size"]) + nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"]) + num_code_books = cls.get_from_keys(config, ["num_codebooks"]) + out_group_size = cls.get_from_keys(config, ["out_group_size"]) + return cls(in_group_size, nbits_per_codebook, num_code_books, + out_group_size) + + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]: + if isinstance(layer, LinearBase): + return AQLMLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class AQLMLinearMethod(LinearMethodBase): + """Linear method for AQLM. + + Args: + quant_config: The AQLM quantization config. + """ + + def __init__(self, quant_config: AQLMConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + del output_size # Unused. + del input_size # Unused. + + if params_dtype != torch.half: + raise ValueError("Only half is currently supported by aqlm") + if input_size_per_partition % self.quant_config.in_group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.out_group_size != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + codes = Parameter( + torch.empty( + # There could actually be two pack factors, one along input and + # one along output, but we don't currently support + # out_group_size, and only the one along output needs to be + # marked with "packed_dim" in order for QKVLinear to work. + output_size_per_partition, + input_size_per_partition // self.quant_config.pack_factor, + self.quant_config.num_codebooks, + dtype=get_int_dtype(self.quant_config.nbits_per_codebook), + ), + requires_grad=False, + ) + + set_weight_attrs( + codes, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }, + ) + + codebooks = Parameter( + torch.empty( + self.quant_config.num_codebooks * len(output_partition_sizes), + 2**self.quant_config.nbits_per_codebook, + self.quant_config.out_group_size, + self.quant_config.in_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + codebooks, + { + # metadata indicates fixed size concatenated along dim 0 + "is_metadata": + True, + "output_partition_sizes": + torch.tensor(output_partition_sizes, device='cpu'), + }, + ) + + scales = Parameter( + torch.empty( + ( + output_size_per_partition // + self.quant_config.out_group_size, + 1, + 1, + 1, + ), + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "output_dim": 0, + "packed_dim": 0, + "pack_factor": self.quant_config.out_group_size + }, + ) + + layer.register_parameter("codes", codes) + set_weight_attrs(codes, extra_weight_attrs) + layer.register_parameter("codebooks", codebooks) + set_weight_attrs(codebooks, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + codebooks = layer.codebooks + codes = layer.codes + scales = layer.scales + output_partition_sizes = getattr(codebooks, "output_partition_sizes", + None) + + nbooks = codes.shape[2] + ingroups = codebooks.shape[3] + outgroups = codebooks.shape[2] + bits = codebooks.shape[1] + + # We support these formats with dedicated gemm and decompression + # kernels. + if ingroups == 8 and outgroups == 1 and ( + (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)): + + # thresholds determined by timings on an A6000, one GPU + use_gemv = math.prod(x.shape[:-1]) <= 6 + + return ops.aqlm_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) if use_gemv else optimized_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) + + # fall back all unoptimized formats + return generic_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2caef5f1ebf50..f4fc7ce020e95 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -3,11 +3,11 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs class AWQConfig(QuantizationConfig): @@ -62,8 +62,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": zero_point = cls.get_from_keys(config, ["zero_point"]) return cls(weight_bits, group_size, zero_point) - def get_linear_method(self) -> "AWQLinearMethod": - return AWQLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]: + if isinstance(layer, LinearBase): + return AWQLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] @@ -79,15 +82,18 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The output size is not aligned with the quantized " @@ -136,19 +142,21 @@ def create_weights(self, input_size_per_partition: int, "input_dim": 0, "output_dim": 1, }) - return { - "qweight": qweight, - "qzeros": qzeros, - "scales": scales, - } - - def apply_weights(self, - weights: Dict[str, Any], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - scales = weights["scales"] - qzeros = weights["qzeros"] + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros pack_factor = self.quant_config.pack_factor out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) @@ -163,5 +171,5 @@ def apply_weights(self, out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: - out = out + bias + out.add_(bias) return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 6115e7c3be956..e7de283b562a6 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -1,9 +1,34 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch +from torch import nn -from vllm.model_executor.layers.linear import LinearMethodBase + +class QuantizeMethodBase(ABC): + """Base class for different quantized methods.""" + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, *weight_args, + **extra_weight_attrs): + """Create weights for a layer. + + The weights will be set as attributes of the layer.""" + raise NotImplementedError + + @abstractmethod + def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + def process_weights_after_loading(self, layer: nn.Module) -> None: + """Process the weight after loading. + + This can be used for example, to transpose weights for computation. + """ + return class QuantizationConfig(ABC): @@ -41,6 +66,17 @@ def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": """Create a config class from the model's quantization config.""" raise NotImplementedError + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + """ + Detects if this quantization method can support a given checkpoint + format by overriding the user specified quantization method -- + this method should only be overwritten by subclasses in exceptional + circumstances + """ + return None + @staticmethod def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: """Get a value from the model's quantization config.""" @@ -51,8 +87,16 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: "quantization config.") @abstractmethod - def get_linear_method(self) -> LinearMethodBase: - """Get the linear method to use for the quantized linear layer.""" + def get_quant_method( + self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: + """Get the quantize method to use for the quantized layer. + + Args: + layer: The layer for the quant method. + Returns: + The quantize method. None if the given layer doesn't support quant + method. + """ raise NotImplementedError @abstractmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 0000000000000..19e464bd64325 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,151 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor) + + +class CompressedTensorsConfig(QuantizationConfig): + + def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): + self.ignore = ignore + self.layer_quant_details = layer_quant_details + + def get_linear_method(self) -> "CompressedTensorsLinearMethod": + return CompressedTensorsLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16] + + # Need to figure it out + def get_min_capability(self) -> int: + return 60 + + def get_name(self) -> str: + return "compressed_tensors" + + def get_quant_method( + self, layer: torch.nn.Module + ) -> Optional["CompressedTensorsLinearMethod"]: + if isinstance(layer, LinearBase): + return CompressedTensorsLinearMethod(self) + return None + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + layer_quant_details: Dict[str, Any] = dict() + ignore: List[str] = config.get("ignore", None) + + for key, quant_config in config["config_groups"].items(): + targets = quant_config.get("targets") + for target in targets: + layer_quant_details[target] = {} + layer_quant_details[target]["weight"] = quant_config.get( + "weights") + layer_quant_details[target]["input"] = quant_config.get( + "input_activations") + + return cls(layer_quant_details=layer_quant_details, ignore=ignore) + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + def _get_schema(self, weight_quant: Dict, input_quant: Dict): + # TODO: Refactor as additional cases are supported + + weight_bit = weight_quant.get("num_bits") + input_bit = input_quant.get("num_bits") + + weight_strategy = weight_quant.get("strategy") + input_strategy = input_quant.get("strategy") + + weight_symmetric = weight_quant.get("symmetric") + input_symmetric = input_quant.get("symmetric") + + is_8_bits = weight_bit == input_bit == 8 + is_tensor = weight_strategy == input_strategy == "tensor" + is_symmetric = weight_symmetric and input_symmetric + + if is_8_bits and is_tensor and is_symmetric and \ + torch.cuda.is_available(): + # CompressedTensorsW8A8StaticTensor only supports CUDA path for + # now. + return CompressedTensorsW8A8StaticTensor() + raise NotImplementedError( + "Scheme not supported. Only CUDA, 8-bit static symmtetric " + "per tensor quantization is currently supported") + + def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": + + # TODO: update with matching function from `compressed_tensors` + layer_type_name = None + layer_name_class = type(layer).__name__.lower() + for target in self.layer_quant_details: + if target.lower() in layer_name_class: + layer_type_name = target + break + if layer_type_name is None: + raise ValueError(f"Could not matching target for layer {layer}") + + layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( + layer_type_name, None) + if layer_quant_details is None: + raise ValueError( + f"Could not find quantization details for {layer}.") + + return self._get_schema(weight_quant=layer_quant_details["weight"], + input_quant=layer_quant_details["input"]) + + +class CompressedTensorsLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. + """ + weight_loader = extra_weight_attrs.get("weight_loader") + + scheme = self.quantization_config.get_scheme(layer=layer) + scheme.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + layer.scheme = scheme + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. + """ + + if bias is not None: + raise ValueError("bias is not supported for this linear method") + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 0000000000000..831905b63e2c9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -0,0 +1,5 @@ +from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401 +from .compressed_tensors_unquantized import ( # noqa: F401 + CompressedTensorsUnquantized) +from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501 + CompressedTensorsW8A8StaticTensor) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py new file mode 100644 index 0000000000000..3a5904208656e --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod + +import torch + +__all__ = ["CompressedTensorsScheme"] + + +class CompressedTensorsScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by CompressedTensors. + """ + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: toch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py new file mode 100644 index 0000000000000..0cfac13d1ca25 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -0,0 +1,39 @@ +from typing import Callable, List + +import torch +import torch.nn.functional as F +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsUnquantized"] + + +class CompressedTensorsUnquantized(CompressedTensorsScheme): + """ + Implements the scheme for all layers which are ignored + in the CompressedTensors config. The input and loaded weight are used + in a linear transformation. + """ + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + device="cuda", + dtype=params_dtype), + requires_grad=False) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + return F.linear(x, weight) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py new file mode 100644 index 0000000000000..64a88b01cd260 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -0,0 +1,130 @@ +from typing import Callable, List, Tuple, Union + +import torch +from torch.nn import Parameter + +from vllm import _custom_ops as custom_ops +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsW8A8StaticTensor"] + + +class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + assert isinstance(shard_id, str) + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + + def scales_shard_splitter( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int], + logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shard_id = self._shard_id_as_int(shard_id) + offset = sum(logical_widths[:shard_id]) + size = logical_widths[shard_id] + # update loaded weight with copies for broadcast. + loaded_weight = loaded_weight.repeat(size) + return param[offset:offset + size], loaded_weight + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + # TODO: remove zero_point parameters once the configs given remove them + + # Note on input/weight scales and zero_points + # + # When the scales have a single value, it is required that they be + # on the CPU for 2 reasons, + # 1. Performance: + # When the scales (input_scale/weight_scales) have only a single + # value, we perform a scalar broadcast of that value during the + # quant/dequant operations. The "quant" and the "gemm+dequant" + # kernels accept the Scalar by-value. These tensors are allocated + # on the CPU in order to avoid the GPU-to-CPU copy when passing + # by-value. + # + # 2. CUDA Graphs: + # CUDA Graphs don't support GPU-to-CPU copy operations during + # stream capture. + # + # TODO: zero-points are not supported yet. But we expect a similar + # pattern. + + is_tensor_partitioned = len(output_partition_sizes) != 1 + weight_scale_dim = sum( + output_partition_sizes) if is_tensor_partitioned else 1 + weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda" + + input_scale = Parameter(torch.empty(1, + device="cpu", + dtype=torch.float32), + requires_grad=False) + input_zero_point = Parameter(torch.empty(1, + device="cpu", + dtype=torch.int8), + requires_grad=False) + + weight_scale = Parameter(torch.empty(weight_scale_dim, + device=weight_scale_device, + dtype=torch.float32), + requires_grad=False) + weight_zero_point = Parameter(torch.empty(1, + device="cpu", + dtype=torch.int8), + requires_grad=False) + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + requires_grad=False) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + "weight_loader": weight_loader, + "input_dim": 1, + "output_dim": 0, + }) + layer.register_parameter("input_scale", input_scale) + set_weight_attrs(input_scale, { + "weight_loader": weight_loader, + "ignore_warning": True, + }) + layer.register_parameter("input_zero_point", input_zero_point) + set_weight_attrs(input_zero_point, { + "weight_loader": weight_loader, + "ignore_warning": True, + }) + layer.register_parameter("weight_scale", weight_scale) + set_weight_attrs( + weight_scale, { + "weight_loader": weight_loader, + "shard_splitter": self.scales_shard_splitter, + "logical_widths": output_partition_sizes, + "ignore_warning": True, + }) + layer.register_parameter("weight_zero_point", weight_zero_point) + set_weight_attrs(weight_zero_point, { + "weight_loader": weight_loader, + "ignore_warning": True + }) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + weight_scale = layer.weight_scale + act_scale = layer.input_scale + + # Input quantize + x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item()) + + return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, + weight_scale, x.dtype) diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py new file mode 100644 index 0000000000000..31cdffbcf0ab9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -0,0 +1,194 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +class DeepSpeedFPConfig(QuantizationConfig): + """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. + + Args: + weight_bits: the target quantization bits, 6 or 8. + group_size: group size for quantizaiton, default to 128. + """ + + def __init__( + self, + weight_bits: int = 8, + group_size: int = 512, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.valid_types = [torch.bfloat16, torch.float16] + + if self.weight_bits not in (6, 8): + raise ValueError( + "Currently, only 6-bit or 8-bit weight quantization are " + f"supported for DeepSpeed FP quantizaiton, but got " + f"{self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " + f"group_size={self.group_size}") + + @classmethod + def get_name(cls) -> str: + return "DeepSpeedFP" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits=weight_bits, group_size=group_size) + + def get_linear_method(self) -> "DeepSpeedFPLinearMethod": + return DeepSpeedFPLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]: + if isinstance(layer, LinearBase): + return DeepSpeedFPLinearMethod(self) + return None + + +class DeepSpeedFPLinearMethod(LinearMethodBase): + """Linear method for DeepSpeedFP quantizer. + + Args: + quant_config: the DeepSpeedFP quantization config. + """ + + def __init__(self, quant_config: DeepSpeedFPConfig): + self.quant_config = quant_config + self.weight = None + + def create_weights(self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader=None, + **extra_weight_attrs): + del output_size + del input_size + output_size_per_partition = sum(output_partition_sizes) + weight = DeepSpeedFPParameter( + torch.Size((output_size_per_partition, input_size_per_partition)), + params_dtype=params_dtype, + quant_config=self.quant_config, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + layer.register_parameter("weight", weight) + + def quant_weight_loader(param, loaded_weight, *args, **kwargs): + # Calls the original weight loader (if any), quantizes the result, + # and then loads the quantized parameter. + if weight_loader is not None: + orig_param_data = param.data + param.data = param.ds_dequantize() + weight_loader(param, loaded_weight, *args, **kwargs) + param.data, loaded_weight = orig_param_data, param.data + param.ds_quantize_(loaded_weight.cuda()) + + extra_weight_attrs["weight_loader"] = quant_weight_loader + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + y = weight.ds_dequantize() + return F.linear(x, y, bias) + + +class DeepSpeedFPParameter(nn.Parameter): + """ + DeepSpeedFP quantized parameter class that implements fp8/fp6 + quantization deepspeed. Weights are stored in quantized form on + GPUs, and can be dequantized on-the-fly when needed by the model. + """ + + def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, + quant_config: DeepSpeedFPConfig): + try: + import deepspeed + if deepspeed.__version__ < "0.14.2": + raise ImportError("deepspeed version is wrong. Please " + "install deepspeed>=0.14.2.") + from deepspeed.ops.fp_quantizer import FP_Quantize + except ImportError as err: + raise ImportError("Please install deepspeed>=0.14.2 via " + "`pip install deepspeed>=0.14.2` to use " + "deepspeedfp quantizer.") from err + data = torch.empty(( + orig_shape.numel() // quant_config.group_size, + quant_config.group_size * quant_config.weight_bits // 8 + 4, + ), + dtype=torch.int8) + self = torch.Tensor._make_subclass(cls, data, data.requires_grad) + self.orig_shape = orig_shape + self.quant_config = quant_config + self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size) + self.fp_quantizer.orig_shape = orig_shape + self.fp_quantizer.orig_dtype = params_dtype + return self + + def ds_quantize_(self, tensor: torch.Tensor): + assert tensor.device.type == "cuda" and tensor.dtype != torch.int8 + return self.data.copy_( + self.fp_quantizer.quantize( + tensor.data, + q_bits=self.quant_config.weight_bits, + )) + + def ds_dequantize(self, fp_out=None) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.dequantize( + self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) + + def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: + """ + Return a tensor where only the weights at `indices` are dequantized + (to save HBM -> SRAM bandwidth). + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.selective_dequantize( + self.data, + indices, + fp_out=fp_out, + q_bits=self.quant_config.weight_bits) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py new file mode 100644 index 0000000000000..b084b9cee4983 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -0,0 +1,313 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import print_warning_once + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class Fp8Config(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme) + + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + return Fp8LinearMethod(self) + if isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def _create_scale_param( + self, + scale_name: str, + layer: torch.nn.Module, + output_partition_sizes: List[int], + **extra_weight_attrs, + ) -> None: + scale = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.float32), + requires_grad=False) + layer.register_parameter(scale_name, scale) + set_weight_attrs( + scale, { + **extra_weight_attrs, + "fp8_scales_shard_indexer": + self.scales_shard_indexer, + }) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + + layer.process_after_load = True + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype), + requires_grad=False) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }) + + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + self._create_scale_param( + scale_name="weight_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + # ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + self._create_scale_param( + scale_name="act_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + def scales_shard_indexer( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]: + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, int): + pass + elif isinstance(shard_id, str): + if shard_id not in qkv_idxs: + raise ValueError(f"Unknown shard_id: {shard_id}") + shard_id = qkv_idxs[shard_id] + else: + ValueError(f"Shard id must be int or str but got {type(shard_id)}") + + return param[shard_id], loaded_weight + + def process_weights_after_loading(self, layer: Module) -> None: + if (not hasattr(layer, "process_after_load") + or not layer.process_after_load): + return + + # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, + scale=None) + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.logical_widths = None + layer.act_scale = None + return + + # If checkpoint is fp8, requantize the separately quantized logical + # weights into a single fp8 weight with a single weight scale. + else: + # WEIGHT_SCALE / WEIGHT + # Loop over logical weights, requantizing with single scale. + max_w_scale = layer.weight_scale.max() + start = 0 + for idx, logical_width in enumerate(layer.logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize(layer.weight[start:end, :], + layer.weight_scale[idx]) + + layer.weight[start:end, :] = per_tensor_quantize( + weight_dq, layer.weight_scale.max()) + start = end + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # WEIGHT + # Transpose weight for passing to torch._scaled_mm + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + # ACT_SCALE + # Dynamic: set to None (required input to ops.scaled_fp8_quant). + # Static: set to max of the act_scales (since they are equal). + if self.quant_config.activation_scheme == "dynamic": + layer.act_scale = None + elif self.quant_config.activation_scheme == "static": + if not all_close_1d(layer.act_scale): + raise ValueError( + "All the act_scales for the logical weights of a layer " + f"must be equal. But got {layer.act_scale}") + layer.act_scale = Parameter(layer.act_scale.max(), + requires_grad=False) + else: + raise ValueError( + f"Unknown scheme {self.quant_config.activation_scheme}") + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.act_scale is None and x_scale computed from x. + # If static, layer.act_scale is scalar and x_scale set to act_scale. + qinput, x_scale = ops.scaled_fp8_quant(x, + layer.act_scale, + batch_dim_padding=17) + + # Fused GEMM_DQ -- note we padded the input above because + # torch._scaled_mm is more performant for matrices with + # batch dimension > 16. Note that this could change + # in the future. + output, _ = torch._scaled_mm( + qinput, + layer.weight, + out_dtype=x.dtype, + scale_a=x_scale, + scale_b=layer.weight_scale, + bias=bias, + ) + + return torch.narrow(output, 0, 0, x.shape[0]) + + +class Fp8KVCacheMethod(QuantizeMethodBase): + """Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """Create "weight" (aka kv_scale) for an attention layer. + + Args: + layer: The layer that is using the QuantizeMethodBase factory. + """ + # Initialize the KV cache scale to 1.0 as the default value. + # If the kv_scale appears in the checkpoint, it will be + # overwritten when loading weights. + layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") + + def process_weights_after_loading(self, layer: Module) -> None: + # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + if layer.kv_cache_dtype != "auto": + kv_scale = layer.kv_scale.to("cpu").tolist() + if not isinstance(kv_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + layer._kv_scale = kv_scale + if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This may " + "cause accuracy issues. Please make sure kv-cache scaling " + "factor is available in the fp8 checkpoint.") + del layer.kv_scale + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + + +def per_tensor_quantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fn) + + +def per_tensor_dequantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py index 55642ae23ac6d..caa53fb6ceee8 100644 --- a/vllm/model_executor/layers/quantization/fp8_rocm.py +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -1,30 +1,42 @@ -from typing import Any, Dict, Iterator, List, Optional, Tuple -from safetensors import safe_open +from typing import Any, Dict, List, Optional, Tuple, Union, Iterator + import torch +from torch.nn import Module from torch.nn.parameter import Parameter import torch.nn.functional as F -from vllm.model_executor.layers.linear import LinearMethodBase +from safetensors import safe_open + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, -) -from vllm._C import ops + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + import pandas as pd import os +try: + from vllm._C import ops as vllm_ops +except ImportError: + pass + +logger = init_logger(__name__) + class Fp8RocmConfig(QuantizationConfig): - def __init__(self, config) -> None: - self.quantized_weights_path = config["quantized_weights"] + def __init__(self) -> None: + # self.quantized_weights_path = config["quantized_weights"] self._tuned = {} self._stats = {} gemm_type = os.getenv("FP8_GEMM", "fp8_16") #print(f"Integral Cross factor = {self.factor}") if gemm_type == "fp8_8": - self.gemm_method = Fp8RocmLinearLayer.apply_fp8_8 - tuned_filename = "/tuned_fp8_8.csv" + self.gemm_method = Fp8RocmLinearMethod.apply_fp8_8 + tuned_filename = "/projects/tuned_fp8_8.csv" elif gemm_type == "fp8_16": - self.gemm_method = Fp8RocmLinearLayer.apply_fp8_16 - tuned_filename = "/tuned_fp8_16.csv" + self.gemm_method = Fp8RocmLinearMethod.apply_fp8_16 + tuned_filename = "/projects/tuned_fp8_16.csv" else: raise Exception(f"Unknown fp8 gemm type: {gemm_type}") try: @@ -37,12 +49,12 @@ def __init__(self, config) -> None: m = shape["M"] n = shape["N"] k = shape["K"] - algo = shape["solidx"] + algo = shape["algo"] self._tuned[(m, n, k)] = algo - @staticmethod - def get_config_filenames() -> List[str]: - return ["serenity_config.json"] + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] @classmethod def from_config(cls, config) -> "Fp8RocmConfig": @@ -55,91 +67,143 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod # Need to figure it out def get_min_capability(cls) -> int: - return 80 + return 94 @classmethod def get_name(cls) -> str: return "Fp8Rocm" - def get_linear_method(self) -> "Fp8RocmLinearLayer": - return Fp8RocmLinearLayer(self) + def get_quant_method(self, + layer: torch.nn.Module) -> Optional["Fp8RocmLinearMethod"]: + if isinstance(layer, LinearBase): + return Fp8RocmLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] -def set_weight_attrs( - weight: torch.Tensor, - weight_attrs: Optional[Dict[str, Any]], -): - """Set attributes on a weight tensor. - - This method is used to set attributes on a weight tensor. This method - will not overwrite existing attributes. - - Args: - weight: The weight tensor. - weight_attrs: A dictionary of attributes to set on the weight tensor. - """ - if weight_attrs is None: - return - for key, value in weight_attrs.items(): - assert not hasattr( - weight, key - ), f"Overwriting existing tensor attribute: {key}" - setattr(weight, key, value) +class Fp8RocmLinearMethod(LinearMethodBase): + def __init__(self, config: Fp8RocmConfig): + self._config = config -class Fp8RocmLinearLayer(LinearMethodBase): - def __init__(self, config: Fp8RocmConfig) -> None: - self._config = config + def _create_scale_param( + self, + scale_name: str, + layer: torch.nn.Module, + output_partition_sizes: List[int], + **extra_weight_attrs, + ) -> None: + scale = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.float32), + requires_grad=False) + layer.register_parameter(scale_name, scale) + set_weight_attrs( + scale, { + **extra_weight_attrs, + "fp8_scales_shard_indexer": + self.scales_shard_indexer, + }) - def get_tensor(self) -> Iterator[Tuple[str, torch.Tensor]]: - with safe_open( - self._config.quantized_weights_path, framework="pt" - ) as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: - # for a, b in self.get_tensor(): - # print(f"{a}: {b.shape}") - # pass + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) weight = Parameter( torch.empty( output_size_per_partition, input_size_per_partition, - dtype=params_dtype, + dtype=torch.float8_e4m3fnuz, ), requires_grad=False, ) - # orig_weight = Parameter(torch.empty(output_size_per_partition, - # input_size_per_partition, - # dtype=params_dtype), - # requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - # set_weight_attrs(orig_weight, {"input_dim": 1, "output_dim": 0}) - return { - "weight": weight, - # "orig_weight": orig_weight, - "activation_scaling_factor": Parameter( - torch.empty(1, dtype=torch.float32, device="cuda") - ), - "weights_scaling_factor": Parameter( - torch.empty(1, dtype=torch.float32, device="cuda") - ), - "output_scaling_factor": Parameter( - torch.empty(1, dtype=torch.float32, device="cuda") - ) - } + layer.process_after_load = True + layer.logical_widths = output_partition_sizes + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0 + }) + + self._create_scale_param( + scale_name="weights_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + self._create_scale_param( + scale_name="activation_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + self._create_scale_param( + scale_name="output_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + + def process_weights_after_loading(self, layer: Module) -> None: + if (not hasattr(layer, "process_after_load") + or not layer.process_after_load): + return + + layer.activation_scaling_factor = Parameter(layer.activation_scaling_factor.max(), + requires_grad=False) + layer.output_scaling_factor = Parameter(layer.output_scaling_factor.reciprocal().max(), + requires_grad=False) + + max_w_scale = layer.weights_scaling_factor.max() + if len(layer.logical_widths) > 1: + start = 0 + for idx, logical_width in enumerate(layer.logical_widths): + end = start + logical_width + weight_dq = _per_tensor_dequantize(layer.weight[start:end, :], + layer.weights_scaling_factor[idx]) + + layer.weight[start:end, :] = _per_tensor_quantize( + weight_dq, max_w_scale) + start = end + layer.weights_scaling_factor = Parameter(max_w_scale, requires_grad=False) + + # WEIGHT + # Transpose weight for passing to torch._scaled_mm + weight = layer.weight + layer.weight = Parameter(weight, requires_grad=False) + + + def scales_shard_indexer( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]: + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, int): + pass + elif isinstance(shard_id, str): + if shard_id not in qkv_idxs: + raise ValueError(f"Unknown shard_id: {shard_id}") + shard_id = qkv_idxs[shard_id] + else: + ValueError(f"Shard id must be int or str but got {type(shard_id)}") + + # To handle the scalar loaded tensor + if loaded_weight.numel() == 1 and len(loaded_weight.shape) != 0: + loaded_weight = torch.scalar_tensor(loaded_weight[0]) + + return param[shard_id], loaded_weight def apply_fp8_16( self, @@ -150,7 +214,7 @@ def apply_fp8_16( osf: torch.Tensor, ) -> torch.Tensor: x8 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) - ops.convert_fp8(x, x8, asf) + vllm_ops.convert_fp8(x8, x, asf) m = weight.shape[0] n = x.shape[0] k = x.shape[1] @@ -159,19 +223,17 @@ def apply_fp8_16( if algo is None: import os - # print(f"Not found: {m} {n} {k}") if os.getenv("TUNE_FP8") == "1": try: - df = pd.read_csv("/fp8_shapes.csv") + df = pd.read_csv("/projects/fp8_shapes.csv") except: df = pd.DataFrame(columns=["M", "N", "K"]) df = pd.concat( [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] ).drop_duplicates() - df.to_csv("/fp8_shapes.csv", index=False) - # print(f"{m},{n},{k}") + df.to_csv("/projects/fp8_shapes.csv", index=False) algo = 0 - res = ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) + res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) return res def apply_fp8_8( @@ -185,7 +247,7 @@ def apply_fp8_8( ) -> torch.Tensor: assert not bias x8 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) - ops.convert_fp8(x, x8, asf) + vllm_ops.convert_fp8(x8, x, asf) m = weight.shape[0] n = x.shape[0] k = x.shape[1] @@ -194,39 +256,48 @@ def apply_fp8_8( if algo is None: import os - # print(f"Not found: {m} {n} {k}") if os.getenv("TUNE_FP8") == "1": try: - df = pd.read_csv("/fp8_shapes.csv") + df = pd.read_csv("/projects/fp8_shapes.csv") except: df = pd.DataFrame(columns=["M", "N", "K"]) df = pd.concat( [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] ).drop_duplicates() - df.to_csv("/fp8_shapes.csv", index=False) - # print(f"{m},{n},{k}") + df.to_csv("/projects/fp8_shapese.csv", index=False) algo = 0 - res = ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo)) + res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo)) res16 = torch.empty_like(res, dtype=torch.float16) - ops.convert_fp8(res, res16, 1/osf) + vllm_ops.convert_fp8(res16, res, 1/osf) return res16 - def apply_weights( + def apply( self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - weight: torch.Tensor = weights["weight"] + weight: torch.Tensor = layer.weight if weight.dtype == torch.float8_e4m3fnuz: - asf: torch.Tensor = weights["activation_scaling_factor"] * 2 - wsf: torch.Tensor = weights["weights_scaling_factor"] * 2 - osf: torch.Tensor = weights["output_scaling_factor"] / 2 - #my_osf: torch.Tensor = self._config.factor / weights["my_osf"] - #with open("ratio.txt", "a") as f: - # f.write(f'{weights["output_scaling_factor"].item()},{weights["my_osf"].item()}\n') - #return self.test(weight, asf, wsf, osf, x, weights["my_osf"], bias) + asf: torch.Tensor = layer.activation_scaling_factor * 2 + wsf: torch.Tensor = layer.weights_scaling_factor * 2 + osf: torch.Tensor = layer.output_scaling_factor / 2 + return self._config.gemm_method(self, x, weight, asf, wsf, osf) - return F.linear(x, weight, bias) \ No newline at end of file + return F.linear(x, weight, bias) + + +def _per_tensor_quantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fnuz) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fnuz) + + +def _per_tensor_dequantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 53baf710ed811..ae9f7019f0592 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -6,11 +6,11 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs class GPTQConfig(QuantizationConfig): @@ -63,8 +63,11 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": desc_act = cls.get_from_keys(config, ["desc_act"]) return cls(weight_bits, group_size, desc_act) - def get_linear_method(self) -> "GPTQLinearMethod": - return GPTQLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -89,18 +92,21 @@ def __init__(self, quant_config: GPTQConfig): def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + **extra_weight_attrs, + ): del output_size # Unused. if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") + output_size_per_partition = sum(output_partition_sizes) if (output_size_per_partition % self.quant_config.pack_factor.numerator != 0): raise ValueError( @@ -179,37 +185,40 @@ def create_weights( "input_dim": scale_and_zero_input_dim, "output_dim": 1, }) - return { - "qweight": qweight, - "g_idx": g_idx, - "qzeros": qzeros, - "scales": scales, - "exllama_state": exllama_state, - } - - def apply_weights(self, - weights: Dict[str, Any], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("g_idx", g_idx) + set_weight_attrs(g_idx, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + layer.exllama_state = exllama_state + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.qweight out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass - if weights["exllama_state"] == ExllamaState.UNINITIALIZED: + if layer.exllama_state == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: - weights["g_idx"] = torch.argsort(weights["g_idx"]).to( - torch.int) + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: - weights["g_idx"] = torch.empty((1, 1), device="meta") - weights["exllama_state"] = ExllamaState.READY - ops.gptq_shuffle(weights["qweight"], weights["g_idx"], + layer.g_idx.data = torch.empty((0, ), + device=layer.g_idx.device) + layer.exllama_state = ExllamaState.READY + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) - output = ops.gptq_gemm(reshaped_x, weights["qweight"], - weights["qzeros"], weights["scales"], - weights["g_idx"], - weights["exllama_state"] == ExllamaState.READY, + output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, + layer.scales, layer.g_idx, + layer.exllama_state == ExllamaState.READY, self.quant_config.weight_bits) if bias is not None: - output = output + bias + output.add_(bias) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py new file mode 100644 index 0000000000000..ae440743fdf8e --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -0,0 +1,457 @@ +import enum +from enum import Enum +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +logger = init_logger(__name__) + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +GPTQ_MARLIN_SUPPORTED_SYM = [True] + + +# Permutations for Marlin scale shuffling +def get_scale_perms(num_bits): + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def get_pack_factor(num_bits): + assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): + scale_perm, scale_perm_single = get_scale_perms(num_bits) + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +class GPTQMarlinConfig(QuantizationConfig): + """Config class for GPTQ Marlin""" + + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, + is_sym: bool) -> None: + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + + # Verify + if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + raise ValueError( + f"Marlin does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Marlin does not support group_size = {self.group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM: + raise ValueError( + f"Marlin does not support is_sym = {self.is_sym}. " + f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.") + + # Init + self.pack_factor = get_pack_factor(weight_bits) + self.tile_size = GPTQ_MARLIN_TILE + self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N + self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K + self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL + + def __repr__(self) -> str: + return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})") + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + return cls(weight_bits, group_size, desc_act, is_sym) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_marlin_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "marlin") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info("Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_marlin for" + " faster inference") + return None + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQMarlinLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def is_marlin_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + num_bits = quant_config.get("bits", None) + group_size = quant_config.get("group_size", None) + sym = quant_config.get("sym", None) + desc_act = quant_config.get("desc_act", None) + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + # If the capability of the device is too low, cannot convert. + major, minor = torch.cuda.get_device_capability() + device_capability = major * 10 + minor + if device_capability < cls.get_min_capability(): + return False + + # Otherwise, can convert if model satisfies marlin constraints. + return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS + and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES + and sym in GPTQ_MARLIN_SUPPORTED_SYM) + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +class GPTQMarlinLinearMethod(LinearMethodBase): + """Linear method for GPTQ Marlin. + + Args: + quant_config: The GPTQ Marlin quantization config. + """ + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Validate dtype + if params_dtype not in [torch.float16, torch.bfloat16]: + raise ValueError(f"The params dtype must be float16 " + f"or bfloat16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_thread_n != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {self.quant_config.min_thread_n}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_thread_k != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {self.quant_config.min_thread_k}.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}.") + + # Detect sharding of scales/zp + + # By default, no sharding over "input dim" + scales_and_zp_size = input_size // group_size + scales_and_zp_input_dim = None + + if self.quant_config.desc_act: + # Act-order case + assert self.quant_config.group_size != -1 + + is_k_full = input_size_per_partition == input_size + + else: + # No act-order case + + # K is always full due to full alignment with + # group-size and shard of scales/zp + is_k_full = True + + # If this is a row-parallel case, then shard scales/zp + if (input_size != input_size_per_partition + and self.quant_config.group_size != -1): + scales_and_zp_size = input_size_per_partition // group_size + scales_and_zp_input_dim = 0 + + # Init buffers + + # Quantized weights + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + **extra_weight_attrs, + "input_dim": 0, + "output_dim": 1, + "packed_dim": 0, + "pack_factor": self.quant_config.pack_factor, + }, + ) + + # Activation order + g_idx = Parameter( + torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + # Ignore warning from fused linear layers such as QKVParallelLinear. + set_weight_attrs( + g_idx, + { + **extra_weight_attrs, "input_dim": 0, + "ignore_warning": True + }, + ) + + g_idx_sort_indices = torch.empty( + g_idx.shape, + dtype=torch.int32, + ) + + # Scales + scales = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + }, + ) + + # Quantized zero-points + qzeros = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + device="meta", + ), + requires_grad=False, + ) + set_weight_attrs( + qzeros, + { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }, + ) + + # Allocate marlin workspace + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_thread_n) * self.quant_config.max_parallel + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + requires_grad=False) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + layer.g_idx_sort_indices = g_idx_sort_indices + layer.workspace = workspace + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.input_size = input_size + layer.is_k_full = is_k_full + layer.marlin_state = GPTQMarlinState.REPACK + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + + size_m = reshaped_x.shape[0] + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + full_size_k = layer.input_size + + out_shape = x.shape[:-1] + (part_size_n, ) + + if layer.marlin_state == GPTQMarlinState.REPACK: + layer.marlin_state = GPTQMarlinState.READY + + # Newly generated tensors need to replace existing tensors that are + # already registered as parameters by vLLM (and won't be freed) + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + cur_device = layer.qweight.device + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int) + + sorted_g_idx = layer.g_idx[g_idx_sort_indices] + + replace_tensor("g_idx", sorted_g_idx) + replace_tensor("g_idx_sort_indices", g_idx_sort_indices) + + else: + # Reset g_idx related tensors + layer.g_idx = Parameter( + torch.empty(0, dtype=torch.int, device=cur_device), + requires_grad=False, + ) + layer.g_idx_sort_indices = Parameter( + torch.empty(0, dtype=torch.int, device=cur_device), + requires_grad=False, + ) + + # Repack weights + marlin_qweight = ops.gptq_marlin_repack( + layer.qweight, + layer.g_idx_sort_indices, + part_size_k, + part_size_n, + self.quant_config.weight_bits, + ) + replace_tensor("qweight", marlin_qweight) + + # Permute scales + scales_size_k = part_size_k + scales_size_n = part_size_n + if self.quant_config.desc_act: + scales_size_k = full_size_k + + marlin_scales = marlin_permute_scales( + layer.scales, + scales_size_k, + scales_size_n, + self.quant_config.group_size, + self.quant_config.weight_bits, + ) + replace_tensor("scales", marlin_scales) + + output = ops.gptq_marlin_gemm( + reshaped_x, + layer.qweight, + layer.scales, + layer.g_idx, + layer.g_idx_sort_indices, + layer.workspace, + self.quant_config.weight_bits, + size_m, + part_size_n, + part_size_k, + layer.is_k_full, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py new file mode 100644 index 0000000000000..6bcfc405afe71 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -0,0 +1,291 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] +GPTQ_MARLIN_24_SUPPORTED_SYM = [True] + + +class GPTQMarlin24Config(QuantizationConfig): + """Config class for Marlin24. + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + + # Verify + if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: + raise ValueError( + f"Marlin_24 does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Marlin_24 does not support group_size = {self.group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " + "are supported.") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.weight_bits + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N + + # Min in_features dim + self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL + + # Permutation length used by the marlin kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return "Marlin24Config(weight_bits={}, group_size={})".format( + self.weight_bits, self.group_size) + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin_24" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + is_marlin_24_format = ( + hf_quant_cfg.get("checkpoint_format") == "marlin_24") + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "gptq_marlin_24") + + if is_marlin_24_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. " + "Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQMarlin24LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class GPTQMarlin24LinearMethod(LinearMethodBase): + """Linear method for Marlin24. + + Args: + quant_config: The Marlin24 quantization config. + """ + + def __init__(self, quant_config: GPTQMarlin24Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.tile_size // 2, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + "marlin_tile_size": self.quant_config.tile_size, + }, + ) + + # Meta + meta = Parameter( + torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + device="cuda", + dtype=torch.int16, + ), + requires_grad=False, + ) + set_weight_attrs( + meta, + { + "input_dim": 0, + "packed_dim": 1, + "pack_factor": 1, + "output_dim": 1, + "marlin_tile_size": 2, + }, + ) + + # Determine if channelwise or not + input_groups = (1 if self.quant_config.group_size == -1 else + input_size_per_partition // + self.quant_config.group_size) + + scales = Parameter( + torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "input_dim": None if input_groups == 1 else 0, + "output_dim": 1, + }, + ) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + workspace = Parameter(torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + requires_grad=False) + + layer.register_parameter("B_24", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("B_meta", meta) + set_weight_attrs(meta, extra_weight_attrs) + layer.register_parameter("s", scales) + set_weight_attrs(scales, extra_weight_attrs) + layer.register_parameter("workspace", workspace) + set_weight_attrs(workspace, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B_24 + meta = layer.B_meta + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, + self.quant_config.weight_bits, + size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 784229878edf4..3613c9d9ecf2a 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -3,11 +3,14 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) class MarlinConfig(QuantizationConfig): @@ -72,8 +75,30 @@ def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) return cls(group_size) - def get_linear_method(self) -> "MarlinLinearMethod": - return MarlinLinearMethod(self) + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" + or hf_quant_cfg.get("is_marlin_format", False)) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "marlin") + + if is_marlin_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. Using {} kernel.". + format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: + if isinstance(layer, LinearBase): + return MarlinLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -91,12 +116,14 @@ def __init__(self, quant_config: MarlinConfig): def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + **extra_weight_attrs, + ): del output_size # Unused. if params_dtype != torch.float16: @@ -104,6 +131,7 @@ def create_weights( f"The params dtype must be float16, but got {params_dtype}") # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) if output_size_per_partition % self.quant_config.min_n_threads != 0: raise ValueError( f"Weight output_size_per_partition = " @@ -187,21 +215,22 @@ def create_weights( dtype=torch.int), requires_grad=False) - return { - "B": qweight, - "s": scales, - "workspace": workspace, - } + layer.register_parameter("B", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("s", scales) + set_weight_attrs(scales, extra_weight_attrs) + layer.register_parameter("workspace", workspace) + set_weight_attrs(workspace, extra_weight_attrs) - def apply_weights( + def apply( self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weights["B"] - scales = weights["s"] - workspace = weights["workspace"] + qweight = layer.B + scales = layer.s + workspace = layer.workspace x_2d = x.view(-1, x.shape[-1]) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index ed25455e6ec1f..207dbcee8afc5 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -3,11 +3,11 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs from vllm.utils import is_hip @@ -51,14 +51,17 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": weight_bits = cls.get_from_keys(config, ["wbits"]) return cls(weight_bits) - def get_linear_method(self) -> "SqueezeLLMLinearMethod": - return SqueezeLLMLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: + if isinstance(layer, LinearBase): + return SqueezeLLMLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] -class SqueezeLLMLinearMethod(LinearMethodBase): +class SqueezeLLMLinearMethod(QuantizeMethodBase): """Linear method for SqueezeLLM. Args: @@ -68,15 +71,18 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) qweight = Parameter( torch.empty( input_size_per_partition // self.quant_config.pack_factor, @@ -103,17 +109,18 @@ def create_weights(self, input_size_per_partition: int, set_weight_attrs(lookup_table, { "output_dim": 0, }) - return { - "qweight": qweight, - "lookup_table": lookup_table, - } - - def apply_weights(self, - weights: Dict[str, Any], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - lookup_table = weights["lookup_table"] + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("lookup_table", lookup_table) + set_weight_attrs(lookup_table, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.qweight + lookup_table = layer.lookup_table out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) if is_hip(): @@ -126,5 +133,5 @@ def apply_weights(self, ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) if bias is not None: - out = out + bias + out.add_(bias) return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/quantization/utils/format_24.py b/vllm/model_executor/layers/quantization/utils/format_24.py new file mode 100644 index 0000000000000..01c8cf789204b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/format_24.py @@ -0,0 +1,308 @@ +# +# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es). +# + +import torch + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, + device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. +def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError( + "Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. + + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, + k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view( + (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix") + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta = torch.gather(meta_reordered.view(-1), 0, + meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( + -1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_(0, dense_offsets, + sparse.view(torch.half).view(-1)) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. + + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " + f"{M} groups") + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask diff --git a/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py new file mode 100644 index 0000000000000..12e77cb710687 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py @@ -0,0 +1,58 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + + +# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501 +# +# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def get_perms_24(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return perm, scale_perm, scale_perm_single + + +marlin_24_perm = {} +marlin_24_scale_perm = {} +marlin_24_scale_perm_single = {} +for num_bits in [4, 8]: + perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) + marlin_24_perm[num_bits] = perm_24 + marlin_24_scale_perm[num_bits] = scale_perm_24 + marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 diff --git a/vllm/model_executor/layers/quantization/utils/marlin_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_perms.py new file mode 100644 index 0000000000000..76bd2ff7c724e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_perms.py @@ -0,0 +1,58 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + + +# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 +# +# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def get_perms(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +marlin_perm = {} +marlin_scale_perm = {} +marlin_scale_perm_single = {} +for num_bits in [4, 8]: + perm, scale_perm, scale_perm_single = get_perms(num_bits) + marlin_perm[num_bits] = perm + marlin_scale_perm[num_bits] = scale_perm + marlin_scale_perm_single[num_bits] = scale_perm_single diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000..0d027d0620ab3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -0,0 +1,224 @@ +"""This file is used for /tests and /benchmarks""" +import random + +import numpy +import torch + +from vllm.model_executor.layers.quantization.utils.format_24 import ( + mask_creator, sparse_semi_structured_from_dense_cutlass) +from vllm.model_executor.layers.quantization.utils.marlin_24_perms import ( + marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single) +from vllm.model_executor.layers.quantization.utils.marlin_perms import ( + marlin_perm, marlin_scale_perm, marlin_scale_perm_single) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_pack_factor, quantize_weights, sort_weights) + +__cuda_arch = torch.cuda.get_device_capability() + +MARLIN_TILE = 16 + + +def is_marlin_supported(): + return __cuda_arch[0] >= 8 + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, + scale_perm_single): + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, + act_order: bool, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, + act_order) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, + marlin_perm[num_bits]) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_scale_perm[num_bits], + marlin_scale_perm_single[num_bits]) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): + assert q_24.shape == (size_k, size_n) + + # Remove zp to normalize over 0 + max_q_val = (1 << num_bits) - 1 + zp = (max_q_val + 1) // 2 + q_24_no_zp = q_24 - zp + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore zp + q_24_comp = q_24_no_zp_comp + zp + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def marlin_24_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, + num_bits, + group_size, + act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + num_bits) + size_k_comp = size_k // 2 + + # Reformat to marlin + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + num_bits, marlin_24_perm[num_bits]) + marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_24_scale_perm[num_bits], + marlin_24_scale_perm_single[num_bits]) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) + + max_workspace_size = ((out_features // min_thread_n) * max_parallel) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000..177cb23f63cf4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -0,0 +1,146 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + +SUPPORTED_NUM_BITS = [4, 8] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +def get_pack_factor(num_bits): + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size, ), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, + act_order: bool): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Reshape to [groupsize, -1] + if group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s + + # Restore original shapes + if group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + s = s.reshape((-1, size_n)).contiguous() + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) + + w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to( + dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index ecd2bd0fce3a3..1f2ab7e2870ca 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -12,15 +12,21 @@ class RejectionSampler(nn.Module): https://arxiv.org/pdf/2302.01318.pdf. """ - def __init__(self, strict_mode: bool = False): + def __init__(self, + disable_bonus_tokens: bool = True, + strict_mode: bool = False): """Create a rejection sampler. Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks during sampling. This catches correctness issues but adds nontrivial latency. """ super().__init__() + self._disable_bonus_tokens = disable_bonus_tokens self._strict_mode = strict_mode # NOTE: A "bonus token" is accepted iff all proposal tokens are @@ -116,6 +122,7 @@ def forward( draft_token_ids, bonus_token_ids, ) + return output_token_ids def _batch_modified_rejection_sampling( @@ -144,6 +151,7 @@ def _batch_modified_rejection_sampling( recovered_probs = self._get_recovered_probs( target_probs, draft_probs).reshape(batch_size * k, vocab_size) + # NOTE: the recovered_probs are overwritten by this method. recovered_token_ids = _multinomial(recovered_probs, num_samples=1).reshape( batch_size, k) @@ -307,6 +315,13 @@ def _create_output( output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, bonus_token_ids, -1) + # We disable bonus tokens because it causes corrupt KV cache for + # proposal methods that require KV cache. We can fix it by "prefilling" + # the bonus token in the proposer. The following issue tracks the fix. + # https://github.com/vllm-project/vllm/issues/4212 + if self._disable_bonus_tokens: + output_with_bonus_tokens[:, -1] = -1 + # Fill the recovered token ids. output.mul_(~after_false_mask).add_( recovered_token_ids.mul(after_false_mask)) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index d80e73bbe39e9..d03903d206d33 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn -from vllm._C import ops +from vllm import _custom_ops as ops def _rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -53,6 +53,7 @@ def __init__( max_position_embeddings: int, base: int, is_neox_style: bool, + dtype: torch.dtype, ) -> None: super().__init__() self.head_size = head_size @@ -60,9 +61,10 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.base = base self.is_neox_style = is_neox_style + self.dtype = dtype cache = self._compute_cos_sin_cache() - cache = cache.to(torch.get_default_dtype()) + cache = cache.to(dtype) self.register_buffer("cos_sin_cache", cache, persistent=False) def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: @@ -108,7 +110,8 @@ def _forward( query_pass = query[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:] - self.cos_sin_cache = self.cos_sin_cache.to(positions.device) + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device, dtype=query.dtype) cos_sin = self.cos_sin_cache[torch.add(positions, offsets) if offsets is not None else positions] cos, sin = cos_sin.chunk(2, dim=-1) @@ -142,7 +145,8 @@ def forward( key: torch.Tensor, offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - self.cos_sin_cache = self.cos_sin_cache.to(positions.device) + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + dtype=query.dtype) # ops.rotary_embedding()/batched_rotary_embedding() # are in-place operations that update the query and key tensors. if offsets is not None: @@ -155,10 +159,39 @@ def forward( self.cos_sin_cache, self.is_neox_style) return query, key + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + class LinearScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with linear scaling. + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + Credits to the Reddit user /u/kaiokendev """ @@ -170,16 +203,22 @@ def __init__( base: int, is_neox_style: bool, scaling_factors: Union[List[float], float], + dtype: torch.dtype, ) -> None: if isinstance(scaling_factors, float): scaling_factors = [scaling_factors] - self.scaling_factors = scaling_factors + self.scaling_factors: List[float] = scaling_factors # noqa super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + is_neox_style, dtype) + # Lazy initialized. + self._scaling_factor_to_offset: Dict[float, int] def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.base) - cache_list = [] + cache_list: List[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: List[int] = [] for scaling_factor in self.scaling_factors: # NOTE(woosuk): self.max_position_embeddings is the original # maximum length before applying the rope scaling. @@ -193,9 +232,25 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: cos = freqs.cos() sin = freqs.sin() cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) return torch.cat(cache_list, dim=0) + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self._scaling_factor_to_offset + class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with Dynamic NTK scaling. @@ -211,10 +266,11 @@ def __init__( base: int, is_neox_style: bool, scaling_factor: float, + dtype: torch.dtype, ) -> None: self.scaling_factor = scaling_factor super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + is_neox_style, dtype) def _compute_cos_sin_cache(self) -> torch.Tensor: # NOTE(woosuk): self.max_position_embeddings is the original @@ -247,11 +303,12 @@ def _yarn_find_correction_dim(num_rotations: int, # Find dim range bounds based on rotations -def _yarn_find_correction_range(low_rot: int, - high_rot: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> int: +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> Tuple[int, int]: low = math.floor( _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) high = math.ceil( @@ -290,11 +347,12 @@ def __init__( base: int, is_neox_style: bool, scaling_factor: float, + dtype: torch.dtype, *, extrapolation_factor: float = 1, attn_factor: float = 1, - beta_fast: float = 32, - beta_slow: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor @@ -305,7 +363,7 @@ def __init__( self.mscale = float( _yarn_get_mscale(self.scaling_factor) * attn_factor) super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + is_neox_style, dtype) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: pos_freqs = self.base**( @@ -336,6 +394,115 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: return cache +class Phi3SuScaledRotaryEmbedding(nn.Module): + """Phi3 family of models scaled rotary embedding. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + original_max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + short_factor: List[float], + long_factor: List[float], + short_mscale: float = 1.1, + long_mscale: float = 1.225, + ): + super().__init__() + + if rotary_dim != head_size: + raise ValueError( + f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \ + head_size ({rotary_dim}!={head_size}).") + if is_neox_style is False: + raise ValueError( + "`Phi3SuScaledRotaryEmbedding` only supports neox_style.") + + self.head_size = head_size + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.base = base + self.short_factor = short_factor + self.long_factor = long_factor + self.short_mscale = short_mscale + self.long_mscale = long_mscale + + short_cache = self._compute_cos_sin_cache( + original_max_position_embeddings, short_factor, short_mscale) + short_cache = short_cache.to(dtype) + self.register_buffer("short_cos_sin_cache", + short_cache, + persistent=False) + + long_cache = self._compute_cos_sin_cache(max_position_embeddings, + long_factor, long_mscale) + long_cache = long_cache.to(dtype) + self.register_buffer("long_cos_sin_cache", + long_cache, + persistent=False) + + long_short_cache = torch.cat( + [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0) + self.register_buffer("long_short_cos_sin_cache", + long_short_cache, + persistent=False) + + def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: + rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) + inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( + 0, self.head_size, 2, dtype=torch.float) / self.head_size))) + return inv_freq + + def _compute_cos_sin_cache( + self, + max_position_embeddings: int, + rescale_factors: List[float], + mscale: float, + ) -> torch.Tensor: + inv_freq = self._compute_inv_freq(rescale_factors) + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + k = self.original_max_position_embeddings + long_prompt_offset = (torch.any(positions > k).float() * + torch.full_like(positions, k)).long() + idx = (torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None else positions) + self.long_short_cos_sin_cache: torch.Tensor = ( + self.long_short_cos_sin_cache.to(idx.device)) + idx = torch.add(idx, offsets) if offsets is not None else idx + cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) + + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).unsqueeze(-2) + sin = sin.repeat(1, 2).unsqueeze(-2) + + query = query * cos + _rotate_neox(query) * sin + key = key * cos + _rotate_neox(key) * sin + + return query.flatten(-2), key.flatten(-2) + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} @@ -346,27 +513,39 @@ def get_rope( base: int, is_neox_style: bool = True, rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, ) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None key = (head_size, rotary_dim, max_position, base, is_neox_style, - tuple(rope_scaling.items()) if rope_scaling is not None else None) + rope_scaling_args, dtype) if key in _ROPE_DICT: return _ROPE_DICT[key] - if rope_scaling is None: rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style) + is_neox_style, dtype) else: scaling_type = rope_scaling["type"] - scaling_factor = rope_scaling["factor"] + if scaling_type != "su": + scaling_factor = rope_scaling["factor"] if scaling_type == "linear": rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, - scaling_factor) + scaling_factor, dtype) elif scaling_type == "dynamic": rotary_emb = DynamicNTKScalingRotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, - scaling_factor) + scaling_factor, dtype) elif scaling_type == "yarn": original_max_position = rope_scaling[ "original_max_position_embeddings"] @@ -379,8 +558,22 @@ def get_rope( rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, original_max_position, base, is_neox_style, - scaling_factor, + scaling_factor, dtype, **extra_kwargs) + elif scaling_type == "su": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3SuScaledRotaryEmbedding( + head_size, rotary_dim, max_position, original_max_position, + base, is_neox_style, dtype, short_factor, long_factor, + **extra_kwargs) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index cb1480de03e3a..a84f562909d50 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -7,12 +7,16 @@ from vllm.model_executor.layers.ops.sample import sample as sample_triton from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingTensors) -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, - SamplerOutput, SequenceData, SequenceGroupOutput, + SamplingTensors, + SequenceGroupToSample) +from vllm.sampling_params import SamplingType +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceOutput) +# (num_token_ids, num_parent_ids) per sequence group. +SampleResultType = List[Tuple[List[int], List[int]]] + class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -27,18 +31,35 @@ class Sampler(nn.Module): 6. Sample the next tokens. Here, each sequence group within the batch can have different sampling parameters (e.g., sampling method, temperature, top-p, top-k, etc.). + + The structure of the logits tensor is coupled with the seq_groups in + sampling_metadata. Typically, each sequence in each seq_group has one row in + logits for the next token to be sampled; however, for a seq_group with a + prompt request with the prompt_logprobs sampling parameter, there are rows + in logits for each token in the input prompt. """ + def __init__(self): + super().__init__() + + # Whether or not the SamplerOutput should have on-device tensors + # containing the sampled token ids and probabilities. This is used by + # speculative decoding. + self.include_gpu_probs_tensor = False + def forward( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: + """ + Args: + logits: (num_tokens, vocab_size). + sampling_metadata: Metadata for sampling. + """ assert logits is not None _, vocab_size = logits.shape - # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens - # have not been generated yet logits = _apply_min_tokens_penalty(logits, sampling_metadata) # Prepare sampling tensors with pinned memory to avoid blocking. @@ -69,17 +90,47 @@ def forward( # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) # Compute the log probabilities. - # Use log_softmax to ensure numerical stability. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results = _sample(probs, logprobs, sampling_metadata, - sampling_tensors) + sample_results, maybe_sampled_tokens_tensor = _sample( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=self.include_gpu_probs_tensor, + modify_greedy_probs=self._should_modify_greedy_probs_inplace, + ) + + if self.include_gpu_probs_tensor: + assert maybe_sampled_tokens_tensor is not None + on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) + else: + on_device_tensors = None + # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, sampling_metadata, - prompt_logprobs, sample_logprobs) + return _build_sampler_output(sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors) + + @property + def _should_modify_greedy_probs_inplace(self) -> bool: + """Whether or not the sampler should modify the probability distribution + of greedily-sampled tokens such that multinomial sampling would sample + the greedily-sampled token. + + In other words, if True then we set the probability of the greedily- + sampled token to 1. + + This is used by speculative decoding, which requires that the sampling + method be encoded into the probability distribution. + """ + # Modify greedy probs if include_gpu_probs_tensor is set. + return self.include_gpu_probs_tensor def _get_bin_counts_and_mask( @@ -103,35 +154,46 @@ def _apply_min_tokens_penalty( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: + """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens + have not been generated yet + """ # list of indices in logits that will be set to -inf - logits_to_penalize = [] - start_idx = 0 - for seq_ids, sampling_params in sampling_metadata.seq_groups: + logits_to_penalize: List[Tuple[int, int]] = [] + logits_applied = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + + sample_indices = seq_group.sample_indices + logits_applied += len(sample_indices) + len( + seq_group.prompt_logprob_indices) + if not seq_group.do_sample: + continue + + start_idx = sample_indices[0] min_tokens = sampling_params.min_tokens - if min_tokens > 0: + token_ids_to_penalize = sampling_params.all_stop_token_ids + if min_tokens > 0 and token_ids_to_penalize: seqs_to_penalize = [] - for i, seq_id in enumerate(seq_ids): - seq_data = sampling_metadata.seq_data[seq_id] + for j, seq_id in enumerate(seq_ids): + seq_data = seq_group.seq_data[seq_id] if len(seq_data.output_token_ids) < min_tokens: - seqs_to_penalize.append(i) + seqs_to_penalize.append(j) if seqs_to_penalize: # convert to the index into logits - seqs_to_penalize = [start_idx + i for i in seqs_to_penalize] - # use set() to remove any duplicates - token_ids_to_penalize = set(sampling_params.stop_token_ids + - [sampling_params.eos_token_id]) + seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] # itertools.product pairs each seq index with every token id logits_to_penalize.extend( itertools.product(seqs_to_penalize, token_ids_to_penalize)) - start_idx += len(seq_ids) - if logits_to_penalize: # use zip and * to group indices along each dimension # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) logits[tuple(zip(*logits_to_penalize))] = -float("inf") + # verifies that no rows in logits were missed unexpectedly + assert logits_applied == logits.shape[0] return logits @@ -208,14 +270,30 @@ def _apply_min_p( def _greedy_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], + selected_seq_groups: List[SequenceGroupToSample], samples: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: + """Run greedy sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + samples: (num_selected_samples,) A tensor of samples. The length of + samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ samples = samples.tolist() sample_idx = 0 - results = [] + results: SampleResultType = [] for seq_group in selected_seq_groups: - seq_ids, _ = seq_group + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids num_parent_seqs = len(seq_ids) assert num_parent_seqs == 1, ( "Greedy sampling should have only one seq.") @@ -227,16 +305,33 @@ def _greedy_sample( def _random_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], + selected_seq_groups: List[SequenceGroupToSample], random_samples: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: + """Run random sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + random_samples: (num_selected_samples,) A tensor of samples. The + length of samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ # Find the maximum best_of value of the prompt phase requests. random_samples = random_samples.cpu() sample_idx = 0 - results = [] - for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): - seq_ids, sampling_params = seq_group + results: SampleResultType = [] + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt num_parent_seqs = len(seq_ids) if is_prompt: # Prompt phase. @@ -254,11 +349,20 @@ def _random_sample( def _beam_search_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], - seq_data: Dict[int, SequenceData], + selected_seq_groups: List[SequenceGroupToSample], logprobs: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: + """Run beam sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + logprobs: (num_selected_samples, vocab_size,) A tensor of logprob + on selected sample indices. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ # We sample 2 * beam_width candidates to make sure that with high # probability we can get `beam_width` candidates in addition to # the finished sequences for the next iteration. See @@ -269,9 +373,14 @@ def _beam_search_sample( # NOTE: Beam search is not vectorized, so its speed can be slower than # other sampling methods. sample_idx = 0 - results = [] - for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): - seq_ids, sampling_params = seq_group + results: SampleResultType = [] + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + is_prompt = seq_group.is_prompt + seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params num_parent_seqs = len(seq_ids) beam_width = sampling_params.best_of seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] @@ -285,15 +394,16 @@ def _beam_search_sample( next_token_ids = next_token_ids.tolist() else: # Generation phase. - cumulative_logprobs = [ - seq_data[seq_id].cumulative_logprob for seq_id in seq_ids + cumulative_logprobs: List[int] = [ + seq_group.seq_data[seq_id].cumulative_logprob + for seq_id in seq_ids ] - cumulative_logprobs = torch.tensor( + cumulative_logprobs_tensor = torch.tensor( cumulative_logprobs, dtype=torch.float, device=seq_group_logprobs.device) seq_group_logprobs = (seq_group_logprobs + - cumulative_logprobs.unsqueeze(dim=1)) + cumulative_logprobs_tensor.unsqueeze(dim=1)) _, topk_ids = torch.topk(seq_group_logprobs.flatten(), 2 * beam_width) topk_ids = topk_ids.tolist() @@ -314,8 +424,7 @@ def _beam_search_sample( def _multinomial( probs: torch.Tensor, num_samples: int, - seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None, - generators: Optional[List[torch.Generator]] = None, + seq_groups: Optional[List[SequenceGroupToSample]] = None, ) -> torch.Tensor: if num_samples > 1: # This is equivalent to torch.repeat_interleaved (which also @@ -331,9 +440,11 @@ def _multinomial( q.exponential_() else: sample_idx = 0 - for (seq_ids, _), generator in zip(seq_groups, generators): + for seq_group in seq_groups: + seq_ids = seq_group.seq_ids next_sample_idx = sample_idx + len(seq_ids) * num_samples - q[sample_idx:next_sample_idx].exponential_(generator=generator) + q[sample_idx:next_sample_idx].exponential_( + generator=seq_group.generator) sample_idx = next_sample_idx return probs.div_(q).argmax(dim=1).view(-1, num_samples) @@ -342,11 +453,15 @@ def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, -) -> List[Tuple[List[int], List[int]]]: - categorized_seq_group_ids = {t: [] for t in SamplingType} + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, +) -> Tuple[SampleResultType, Optional[torch.Tensor]]: + categorized_seq_group_ids: Dict[SamplingType, + List[int]] = {t: [] + for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): - _, sampling_params = seq_group + sampling_params = seq_group.sampling_params sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) @@ -354,6 +469,15 @@ def _sample_with_torch( sample_metadata = {} multinomial_samples = {} + # Create output tensor for sampled token ids. + if include_gpu_probs_tensor: + sampled_token_ids_tensor = torch.empty(logprobs.shape[0], + 1, + dtype=torch.long, + device=logprobs.device) + else: + sampled_token_ids_tensor = None + # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. for sampling_type in SamplingType: @@ -361,56 +485,74 @@ def _sample_with_torch( num_tokens = len(sample_indices) if num_tokens == 0: continue - seq_group_ids = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] - is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] - sample_metadata[sampling_type] = (seq_group_ids, seq_groups, - is_prompts, sample_indices) + + seq_group_id = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] + sample_metadata[sampling_type] = (seq_group_id, seq_groups) + long_sample_indices = sample_indices.long() if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[sample_indices.long()], + greedy_samples = torch.argmax(logprobs[long_sample_indices], dim=-1) + + if include_gpu_probs_tensor: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[ + long_sample_indices] = greedy_samples.unsqueeze(-1) + + if modify_greedy_probs: + # If required, modify the probabilities such that sampling from + # the modified distribution would always sample the argmax + # token id. + _modify_greedy_probs_inplace(logprobs, probs, + long_sample_indices, + greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): max_best_of_in_batch = 1 - for seq_group, is_prompt in zip(seq_groups, is_prompts): - if is_prompt: - _, sampling_params = seq_group + for seq_group in seq_groups: + if seq_group.is_prompt: + sampling_params = seq_group.sampling_params max_best_of_in_batch = max(max_best_of_in_batch, sampling_params.best_of) seeded_args = {} if sampling_type == SamplingType.RANDOM else { "seq_groups": seq_groups, - "generators": sampling_metadata.generators, } + multinomial_samples[sampling_type] = _multinomial( - probs[sample_indices.long()], max_best_of_in_batch, + probs[long_sample_indices], max_best_of_in_batch, **seeded_args) + + if include_gpu_probs_tensor: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[ + long_sample_indices] = multinomial_samples[sampling_type] + elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: raise ValueError(f"Unsupported sampling type: {sampling_type}") # GPU<->CPU sync happens in the loop below. - + # This also converts the sample output to Python objects. for sampling_type in SamplingType: if sampling_type not in sample_metadata: continue - seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[ - sampling_type] + (seq_group_id, seq_groups) = sample_metadata[sampling_type] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample(seq_groups, greedy_samples) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, is_prompts, + sample_results = _random_sample(seq_groups, multinomial_samples[sampling_type]) elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, is_prompts, - sampling_metadata.seq_data, + sample_results = _beam_search_sample(seq_groups, beam_search_logprobs) - sample_results_dict.update(zip(seq_group_ids, sample_results)) + sample_results_dict.update(zip(seq_group_id, sample_results)) sample_results = [ - sample_results_dict[i] + sample_results_dict.get(i, ([], [])) for i in range(len(sampling_metadata.seq_groups)) ] - return sample_results + return sample_results, sampled_token_ids_tensor def _sample_with_triton_kernel( @@ -418,11 +560,13 @@ def _sample_with_triton_kernel( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, -) -> List[Tuple[List[int], List[int]]]: - categorized_seq_group_ids = {t: [] for t in SamplingType} +) -> SampleResultType: + categorized_seq_group_ids: Dict[SamplingType, + List[int]] = {t: [] + for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): - _, sampling_params = seq_group + sampling_params = seq_group.sampling_params sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) @@ -438,17 +582,16 @@ def _sample_with_triton_kernel( num_tokens = len(sample_indices) if num_tokens == 0: continue - seq_group_ids = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] - is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] - sample_metadata[sampling_type] = (seq_group_ids, seq_groups, - is_prompts, sample_indices, + seq_group_id = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] + sample_metadata[sampling_type] = (seq_group_id, seq_groups, + sample_indices, sampled_token_indices) if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, SamplingType.RANDOM_SEED): - for seq_group, is_prompt in zip(seq_groups, is_prompts): - if is_prompt: - _, sampling_params = seq_group + for seq_group in seq_groups: + if seq_group.is_prompt: + sampling_params = seq_group.sampling_params max_best_of_in_batch = max(max_best_of_in_batch, sampling_params.best_of) elif sampling_type == SamplingType.BEAM: @@ -472,34 +615,50 @@ def _sample_with_triton_kernel( for sampling_type in SamplingType: if sampling_type not in sample_metadata: continue - (seq_group_ids, seq_groups, is_prompts, sample_indices, + (seq_group_id, seq_groups, sample_indices, sampled_token_indices) = sample_metadata[sampling_type] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample( seq_groups, sampled_tokens[sampled_token_indices][:, 0]) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): sample_results = _random_sample( - seq_groups, is_prompts, sampled_tokens[sampled_token_indices]) + seq_groups, sampled_tokens[sampled_token_indices]) elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, is_prompts, - sampling_metadata.seq_data, + sample_results = _beam_search_sample(seq_groups, beam_search_logprobs) - sample_results_dict.update(zip(seq_group_ids, sample_results)) + sample_results_dict.update(zip(seq_group_id, sample_results)) sample_results = [ - sample_results_dict[i] + sample_results_dict.get(i, ([], [])) for i in range(len(sampling_metadata.seq_groups)) ] return sample_results def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, -) -> List[Tuple[List[int], List[int]]]: - return _sample_with_torch(probs, logprobs, sampling_metadata) + probs: torch.Tensor, logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, modify_greedy_probs: bool +) -> Tuple[SampleResultType, Optional[torch.Tensor]]: + """ + Args: + probs: (num_query_tokens_in_batch, num_vocab) + logprobs: (num_query_tokens_in_batch, num_vocab) + sampling_metadata: The metadata for a batch for sampling. + sampling_tensors: Tensors that include sampling related metadata. + + Returns: + (next_token_ids, parent_seq_ids) for each seq group in a batch. + If sampling is skipped, it returns ([], []) + sampled_token_ids_tensor: A tensor of sampled token ids. + """ + return _sample_with_torch( + probs, + logprobs, + sampling_metadata, + include_gpu_probs_tensor=include_gpu_probs_tensor, + modify_greedy_probs=modify_greedy_probs, + ) # TODO: Enable once Triton kernel & associated code is faster. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, @@ -522,159 +681,339 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), indices] - return (x > vals[:, None]).long().sum(1).add_(1) + result = (x > vals[:, None]) + del vals + return result.sum(1).add_(1) def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, - sample_results: List[Tuple[List[int], List[int]]], -) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[ - int, float]]]]: - # Prepare query indices - batched_logprobs_query_seq_indices: List[int] = [] - batched_logprobs_query_token_indices: List[int] = [] - # at least get one logprob for each token + sample_results: SampleResultType, +) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: + """Return sample lobprobs and prompt logprobs. + + The logic consists of 3 parts. + - Select indices to compute logprob from, ranks of token ids, and + the top k token ids from logprobs. + - Compute prompt logprobs if required. + - Compute sample logprobs if required. + + Args: + logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's + logprob per vocab. Sequence groups' query tokens are batched in a + single flattened tensor. For example, assuming there are N + seq groups, it is sorted by prefill tokens for seq_group_1 (if + prompt logprob is enabled), decode tokens for seq_group_1 (if + sampling is required), prefill tokens for seq_group_2, ... + sampling_metadata: The sampling metadata. + sample_results: (num_seq_groups) The tuple of (next_token_ids, + parent_ids) for each sequence group. When beam search is enabled, + sample_results can contain different number of seq_ids from + sampling_metadata.seq_groups. It is because beam search creates + 2 * BEAM_WIDTH number of samples (whereas there are only up to + BEAM_WIDTH number of seq_ids). + + Returns: + A tuple of prompt and sample logprobs per sequence group in a batch. + """ + # The index of query token to calculate logprobs. It includes both + # prompt and sample logprob indices. + query_indices: List[int] = [] + # The next token ids to get the logprob value from. + next_token_ids: List[int] = [] + # The largest requested number of logprobs. We find logprobs as many as the + # largest num logprobs in this API. largest_num_logprobs = 1 - sample_idx = 0 - for i, (seq_group, sample_result) in enumerate( - zip(sampling_metadata.seq_groups, sample_results)): - seq_ids, sampling_params = seq_group - next_token_ids, parent_ids = sample_result - num_parent_seqs = len(seq_ids) - if (i < sampling_metadata.num_prompts + + # Select indices to compute logprob from, ranks of token ids, and the top + # k token ids from logprobs. + for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, + sample_results): + sampling_params = seq_group.sampling_params + + # Update indices and tokens for prompt logprobs. + if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): largest_num_logprobs = max(largest_num_logprobs, sampling_params.prompt_logprobs) - prompt_len = sampling_metadata.prompt_lens[i] - prompt_tokens = sampling_metadata.seq_data[ - seq_ids[0]].prompt_token_ids - batched_logprobs_query_seq_indices.extend( - sample_idx + j for j in range(prompt_len - 1)) - batched_logprobs_query_token_indices.extend( - token_id for token_id in prompt_tokens[1:]) - sample_idx += prompt_len - 1 - batched_logprobs_query_seq_indices.extend( - [sample_idx + parent_id for parent_id in parent_ids]) - batched_logprobs_query_token_indices.extend(next_token_ids) - if sampling_params.logprobs is not None: - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.logprobs) - sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) - - batched_logprobs_query_seq_indices_gpu = torch.tensor( - batched_logprobs_query_seq_indices, device=logprobs.device) - batched_logprobs_query_token_indices_gpu = torch.tensor( - batched_logprobs_query_token_indices, device=logprobs.device) - - # Batched query for logprobs of selected token - batched_logprobs_query_result = logprobs[[ - batched_logprobs_query_seq_indices_gpu, - batched_logprobs_query_token_indices_gpu + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + query_indices.extend(seq_group.prompt_logprob_indices) + next_token_ids.extend(next_prompt_tokens) + + # Update indices and next tokenes for sample logprob. + if seq_group.do_sample: + token_ids, parent_seq_ids = sample_result + # NOTE: We cannot directly use sample_indices because + # sample_indices only contain parent seq_ids of a previous step. + # The current step may have different number of seq_ids, and + # we can obtain it from `sample_result[1]`. + query_idx = seq_group.sample_indices[0] + query_indices.extend( + [query_idx + parent_id for parent_id in parent_seq_ids]) + next_token_ids.extend(token_ids) + + if sampling_params.logprobs is not None: + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.logprobs) + + assert len(next_token_ids) == len(query_indices) + + if len(query_indices) == 0: + empty_sampled_logprob: SampleLogprobs = [] + empty_prompt_logprob: Optional[PromptLogprobs] = None + return [empty_prompt_logprob], [empty_sampled_logprob] + + query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) + next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device) + + # (num_selected_query_tokens, num_logprobs). Note that query_indices can + # contain duplicates if beam search is enabled. + selected_logprobs = logprobs[[ + query_indices_gpu, + next_token_ids_gpu, ]] + ranks = _get_ranks( + logprobs[query_indices_gpu], + next_token_ids_gpu, + ) + assert selected_logprobs.shape[0] == ranks.shape[0] - batched_ranks_query_result = _get_ranks( - logprobs[batched_logprobs_query_seq_indices_gpu], - batched_logprobs_query_token_indices_gpu) - - # Batched query for logprobs of topk tokens + # Logprobs of topk tokens for a batch of sequence groups. + # (num_query_tokens_across_batch). if largest_num_logprobs > 0: top_logprobs, top_token_ids = torch.topk(logprobs, largest_num_logprobs, dim=-1) - top_logprobs = top_logprobs.cpu() - top_token_ids = top_token_ids.cpu() else: top_logprobs, top_token_ids = None, None - batched_logprobs_query_result = batched_logprobs_query_result.cpu() - batched_ranks_query_result = batched_ranks_query_result.cpu() - - # Gather results - result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] - result_sample_logprobs: List[SampleLogprobs] = [] - sample_idx = 0 - query_result_idx = 0 - for i, (seq_group, sample_result) in enumerate( - zip(sampling_metadata.seq_groups, sample_results)): - seq_ids, sampling_params = seq_group - next_token_ids, parent_ids = sample_result + selected_logprobs = selected_logprobs.to('cpu') + ranks = ranks.to('cpu') + if top_logprobs is not None and top_token_ids is not None: + top_logprobs = top_logprobs.to('cpu') + top_token_ids = top_token_ids.to('cpu') + + # Find prompt/sample logprobs. + prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] + sample_logprobs_per_seq_group: List[SampleLogprobs] = [] + top_logprob_idx = 0 + selected_logprobs_idx = 0 + + for seq_group, sample_result in zip(sampling_metadata.seq_groups, + sample_results): + (prompt_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_prompt_logprob_if_needed( + seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, + selected_logprobs_idx, top_logprob_idx) + prompt_logprobs_per_seq_group.append(prompt_logprobs) + + (sampled_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_sampled_logprob_if_needed( + seq_group, sample_result, selected_logprobs, ranks, top_token_ids, + top_logprobs, selected_logprobs_idx, top_logprob_idx) + sample_logprobs_per_seq_group.append(sampled_logprobs) + + return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group + + +def _get_prompt_logprob_if_needed( + seq_group: SequenceGroupToSample, + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the prompt logprob from a sequence group if needed.""" + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt + + # Find prompt logprobs + prompt_logprobs: Optional[PromptLogprobs] = None + if is_prompt and sampling_params.prompt_logprobs is not None: + prompt_logprobs = [] + num_logprobs = sampling_params.prompt_logprobs + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + # Pre-select indexes and create a list. It is faster than calling .item + # repetitively. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + + for idx, token_id in enumerate(next_prompt_tokens): + # Calculate the prompt logprob of the real prompt tokens. + # {token_id: (logprob, rank_from_vocab)} + prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { + token_id: (selected_logprob_items[idx], rank_items[idx]) + } - # Prompt logprobs - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - num_logprobs = sampling_params.prompt_logprobs - prompt_tokens = sampling_metadata.seq_data[ - seq_ids[0]].prompt_token_ids - group_prompt_logprobs: PromptLogprobs = [None] - for token_id in prompt_tokens[1:]: - prompt_logprobs_dict = { - token_id: - (batched_logprobs_query_result[query_result_idx].item(), - batched_ranks_query_result[query_result_idx].item()) - } - if num_logprobs > 0: - prompt_logprobs_dict.update( - zip( - top_token_ids[sample_idx, :num_logprobs].tolist(), - zip( - top_logprobs[ - sample_idx, :num_logprobs].tolist(), - range(1, num_logprobs + 1)))) - group_prompt_logprobs.append({ - token_id: Logprob(*logprob_rank) - for token_id, logprob_rank in prompt_logprobs_dict.items() + # Add top K prompt logprobs along with its rank. + if num_logprobs > 0: + top_ids = top_token_ids[ + top_logprob_idx, :num_logprobs].tolist() + top_probs = top_logprobs[ + top_logprob_idx, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + prompt_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip(top_ids, top_probs, + top_ranks) }) - sample_idx += 1 - query_result_idx += 1 - result_prompt_logprobs.append(group_prompt_logprobs) - else: - result_prompt_logprobs.append(None) - - # Sample logprobs - num_logprobs = sampling_params.logprobs - if num_logprobs is None: - num_logprobs = 0 - group_sample_logprobs: SampleLogprobs = [] - for next_token_id, parent_id in zip(next_token_ids, parent_ids): - sample_logprobs_dict = { - next_token_id: - (batched_logprobs_query_result[query_result_idx].item(), - batched_ranks_query_result[query_result_idx].item()) + prompt_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in prompt_logprobs_dict.items() + }) + # + 1 to go to the next prompt token. + top_logprob_idx += 1 + + # + len(next_prompt_tokens) to go to the next prompt. + selected_logprobs_idx += len(next_prompt_tokens) + return prompt_logprobs, top_logprob_idx, selected_logprobs_idx + + +def _get_sampled_logprob_if_needed( + seq_group: SequenceGroupToSample, + sample_result: Tuple[List[int], List[int]], + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the sample logprob if needed.""" + seq_ids = seq_group.seq_ids + num_logprobs = seq_group.sampling_params.logprobs or 0 + sampled_logprobs: SampleLogprobs = [] + next_token_ids, parent_seq_ids = sample_result + + if seq_group.do_sample: + assert len(next_token_ids) > 0 + # Pre-select items from tensor. tolist() is faster than repetitive + # `.item()` calls. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + for idx, (next_token_id, + parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)): + # Get the logprob of a sampled token. + sampled_logprobs_dict = { + next_token_id: (selected_logprob_items[idx], rank_items[idx]) } - query_result_idx += 1 - if num_logprobs >= 0: - sample_logprobs_dict.update( - zip( - top_token_ids[sample_idx + - parent_id, :num_logprobs].tolist(), - zip( - top_logprobs[sample_idx + - parent_id, :num_logprobs].tolist(), - range(1, num_logprobs + 1)))) - group_sample_logprobs.append({ - token_id: Logprob(*logprob_rank) - for token_id, logprob_rank in sample_logprobs_dict.items() + # Get top K logprobs. + if num_logprobs > 0: + top_ids = top_token_ids[top_logprob_idx + + parent_id, :num_logprobs].tolist() + top_probs = top_logprobs[top_logprob_idx + + parent_id, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + sampled_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip(top_ids, top_probs, + top_ranks) + }) + + sampled_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in + sampled_logprobs_dict.items() }) - result_sample_logprobs.append(group_sample_logprobs) - sample_idx += len(seq_ids) - return result_prompt_logprobs, result_sample_logprobs + # NOTE: This part of code is not intuitive. `selected_logprobs` include + # logprobs for the current step, which has len(next_token_ids) tokens + # per sequence group. `logprobs` includes logprobs from the previous + # steps, which has len(seq_ids) tokens per sequence group. + + # Iterate to the next sequence group in a batch. + selected_logprobs_idx += len(next_token_ids) + # Iterate to the next sequence group in a batch. + top_logprob_idx += len(seq_ids) + return sampled_logprobs, top_logprob_idx, selected_logprobs_idx + + +def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, + sample_indices: torch.Tensor, + greedy_samples: torch.Tensor) -> None: + """Modify the probability distributions of the greedily-sampled tokens such + that each sampled token has a "probability" of 1.0. This is required by + speculative decoding, which depends on the sampling method being encoded + within the probability distribution for correctness. + + # Why do we only need to do this for greedy sampling? + + vLLM's sampler performs the following steps for greedy or multinomial + (random) sampling: + 1. Get logits from model. + 2. Modify logits according to per-sequence sampling parameters. + - Multiply by temperature, top-k and top-p masking, penalize tokens + according to their frequency, etc. + 3. Sample a token. + - Random sampling simply samples from the modified probability + distribution. + - Greedy sampling performs `argmax` to obtain the token with the + highest likelihood. + + Ignoring greedy sampling for a moment, we find that the computed probability + distribution has the following property: we can sample from it independently + and find that the token sampled by the Sampler has a frequency corresponding + to how often we see it in our sampling. In other words, for tokens sampled + with vLLM's random SamplingType, the computed probability distribution + encodes the sampling methodology completely. + + Greedy sampling does not normally have this property. vLLM modifies logits + according to sampling params, then performs `argmax`, then returns the + sampled token and the computed probability distribution. If we sample from + the distribution, we'll find the likelihood of the greedily-sampled token + is not always 1.0. + + Since lossless speculative decoding requires that the sampling methodology + be encoded within the probability distribution, we are motivated to modify + the probability distribution such that the sampled token has probability 1 + when speculative decoding is used. + + NOTE: Alternatively, we could use an extremely low temperature to achieve + greedy sampling using multinomial computation and unite the codepaths. This + has implications on the overall design of the sampler, e.g. how to record + accurate logprobs for the user, so this improvement is deferred to later. + """ + # NOTE: logprobs are not modified so they can be returned to the user. + probs[sample_indices, :] = 0 + probs[sample_indices, greedy_samples] = 1.0 def _build_sampler_output( - sample_results: List[Tuple[List[int], List[int]]], + sample_results: SampleResultType, sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], + on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor]], ) -> SamplerOutput: + """Construct Python objects with the output of sampling. + + Args: + on_device_tensors: Tuple containing on-device tensors with the + probabilities used in sampling and the sampled token ids. This + allows post-processing without copies to CPU/serialization, e.g. in + speculative decoding rejection sampling. + """ + sampler_output = [] for (seq_group, sample_result, group_prompt_logprobs, group_sample_logprobs) in zip(sampling_metadata.seq_groups, sample_results, prompt_logprobs, sample_logprobs): - seq_ids, _ = seq_group + seq_ids = seq_group.seq_ids next_token_ids, parent_ids = sample_result seq_outputs = [] for parent_id, next_token_id, logprobs in zip(parent_ids, @@ -683,5 +1022,52 @@ def _build_sampler_output( seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( - SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) - return SamplerOutput(outputs=sampler_output) + CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs)) + + # If not specified, store None values in SamplerOutput. + if on_device_tensors is not None: + (sampled_token_probs, logprobs_tensor, + sampled_token_ids) = on_device_tensors + else: + sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, + None) + + return SamplerOutput( + outputs=sampler_output, + sampled_token_probs=sampled_token_probs, + sampled_token_ids=sampled_token_ids, + logprobs=logprobs_tensor, + ) + + +def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: + """Get a list of next prompt tokens to compute logprob from a + given sequence group. + + It is used to compute prompt logprob. Imagine you have logprob for each + query token. Query token needs to know the next prompt token id to compute + prompt logprob. This is a helper to obtain next prompt token ids. + + This API has to be used only when the caller knows seq_group is in prefill + stage. + + Returns: + A list of next prompt tokens to compute logprob. + """ + assert seq_group.is_prompt, ( + "Caller should ensure the sequence group is in a prefill stage.") + seq_ids = seq_group.seq_ids + query_len = seq_group.query_len + assert query_len is not None + # prompt has only 1 seq id. + assert len(seq_ids) == 1 + seq_data = seq_group.seq_data[seq_ids[0]] + computed_len = seq_data.get_num_computed_tokens() + prompt_tokens = seq_data.prompt_token_ids + # +1 because we are looking for a next prompt token. + next_token_index_start = computed_len + 1 + next_token_index_end = min(computed_len + query_len + 1, + len(prompt_tokens)) + next_prompt_tokens = prompt_tokens[ + next_token_index_start:next_token_index_end] + return next_prompt_tokens diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 096945f41bab6..a3d299c05caef 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -34,7 +34,7 @@ def load_best_sols(self): self.bestsols = pd.read_csv(self.tune_path) def create_ds(self): - df = self.bestsols + df: pd.DataFrame = self.bestsols solds = {} for i in range(len(df)): ds = df.iloc[i] @@ -75,6 +75,7 @@ def mm(self, inp, weights): #print(">>> found rocblas") out = rocb_mm(inp_view, weights.t(), solidx) else: + if (self.save_gemm == 1): #print('>>>Tgemm Default',inp_view.shape, # inp.shape,weights.shape,soltype,solidx) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 73bbfac33ed13..4585b1679cb5c 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -4,11 +4,9 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import divide +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.utils import set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -107,6 +105,14 @@ def forward(self, input_): output = tensor_model_parallel_all_reduce(output_parallel) return output + def extra_repr(self) -> str: + s = f"num_embeddings={self.num_embeddings_per_partition}" + s += f", embedding_dim={self.embedding_dim}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f', num_embeddings_padded={self.num_embeddings_padded}' + s += f', tp_size={self.tp_size}' + return s + class ParallelLMHead(VocabParallelEmbedding): """Parallelized LM head. diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py deleted file mode 100644 index cb01edc25fe7e..0000000000000 --- a/vllm/model_executor/model_loader.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Utilities for selecting and loading models.""" -import contextlib -from typing import Tuple, Type - -import torch -import torch.nn as nn - -from vllm.config import DeviceConfig, ModelConfig -from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.models.llava import LlavaForConditionalGeneration -from vllm.model_executor.weight_utils import (get_quant_config, - initialize_dummy_weights) - -_VISION_MODEL_CLASSES = [ - LlavaForConditionalGeneration, -] - - -@contextlib.contextmanager -def _set_default_torch_dtype(dtype: torch.dtype): - """Sets the default torch dtype to the given dtype.""" - old_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - yield - torch.set_default_dtype(old_dtype) - - -def _get_model_architecture( - model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: - architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. - # FIXME(woosuk): This is a temporary hack. - if (model_config.quantization is not None - and "MixtralForCausalLM" in architectures): - architectures = ["QuantMixtralForCausalLM"] - - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def get_architecture_class_name(model_config: ModelConfig) -> str: - return _get_model_architecture(model_config)[1] - - -def get_model(model_config: ModelConfig, device_config: DeviceConfig, - **kwargs) -> nn.Module: - lora_config = kwargs.get("lora_config", None) - vision_language_config = kwargs.get("vision_language_config", None) - model_class = _get_model_architecture(model_config)[0] - - # Get the (maybe quantized) linear method. - linear_method = None - if model_config.quantization is not None: - quant_config = get_quant_config(model_config) - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") - supported_dtypes = quant_config.get_supported_act_dtypes() - if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") - linear_method = quant_config.get_linear_method() - - with _set_default_torch_dtype(model_config.dtype): - # Create a model instance. - # The weights will be initialized as empty tensors. - with torch.device(device_config.device): - if hasattr(model_class, "supported_lora_modules"): - model = model_class(model_config.hf_config, linear_method, - lora_config) - elif lora_config: - raise ValueError( - f"Model {model_class.__name__} does not support LoRA, " - "but LoRA is enabled. Support for this model may " - "be added in the future. If this is important to you, " - "please open an issue on github.") - else: - if model_class not in _VISION_MODEL_CLASSES: - model = model_class(model_config.hf_config, linear_method) - else: - model = model_class(model_config.hf_config, - vision_language_config, linear_method) - if model_config.load_format == "dummy": - # NOTE(woosuk): For accurate performance evaluation, we assign - # random values to the weights. - initialize_dummy_weights(model) - else: - load_ammo_quantized_weights = getattr(model, "load_ammo_quantized_weights", None) - if model_config.quantization is not None and load_ammo_quantized_weights is not None: - load_ammo_quantized_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision, - quant_config) - # Load the weights from the cached or downloaded files. - model.load_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision) - return model.eval() diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py new file mode 100644 index 0000000000000..e3e32d61ab04d --- /dev/null +++ b/vllm/model_executor/model_loader/__init__.py @@ -0,0 +1,33 @@ +from typing import Optional + +from torch import nn + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.model_executor.model_loader.loader import (BaseModelLoader, + get_model_loader) +from vllm.model_executor.model_loader.utils import ( + get_architecture_class_name, get_model_architecture) + + +def get_model(*, model_config: ModelConfig, load_config: LoadConfig, + device_config: DeviceConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig) -> nn.Module: + loader = get_model_loader(load_config) + return loader.load_model(model_config=model_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config) + + +__all__ = [ + "get_model", "get_model_loader", "BaseModelLoader", + "get_architecture_class_name", "get_model_architecture" +] diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py new file mode 100644 index 0000000000000..aa143c65d82b9 --- /dev/null +++ b/vllm/model_executor/model_loader/loader.py @@ -0,0 +1,561 @@ +# ruff: noqa: SIM117 +import collections +import copy +import glob +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, Generator, List, Optional, Tuple, Type + +import huggingface_hub +import torch +from torch import nn + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VisionLanguageConfig) +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, + tensorizer_weights_iterator) +from vllm.model_executor.model_loader.utils import (get_model_architecture, + set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, + get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, + pt_weights_iterator, safetensors_weights_iterator) +from vllm.model_executor.models.vlm_base import VisionLanguageModelBase + +logger = init_logger(__name__) + + +def _get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + quant_config = get_quant_config(model_config, load_config) + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + return quant_config + return None + + +def _get_model_initialization_kwargs( + model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig] +) -> Dict[str, Any]: + """Get extra kwargs for model initialization.""" + extra_kwargs = {} + if hasattr(model_class, "supported_lora_modules"): + extra_kwargs["lora_config"] = lora_config + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + elif issubclass(model_class, VisionLanguageModelBase): + if vision_language_config is None: + raise ValueError("Provide `image_input_type` and other vision " + "related configurations through LLM entrypoint " + "or engine arguments.") + + extra_kwargs["vision_language_config"] = vision_language_config + return extra_kwargs + + +def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig) -> nn.Module: + """Initialize a model with the given configurations.""" + model_class = get_model_architecture(model_config)[0] + quant_config = _get_quantization_config(model_config, load_config) + + return model_class(config=model_config.hf_config, + cache_config=cache_config, + quant_config=quant_config, + **_get_model_initialization_kwargs( + model_class, lora_config, vision_language_config)) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + """Load a model with the given configurations.""" + ... + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf(model_name_or_path, + self.load_config.download_dir, + allow_patterns, revision) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, self.load_config.download_dir, + revision) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], + fall_back_to_pt: bool + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision, fall_back_to_pt) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + return np_cache_weights_iterator(model_name_or_path, + self.load_config.download_dir, + hf_folder, hf_weights_files) + if use_safetensors: + return safetensors_weights_iterator(hf_weights_files) + return pt_weights_iterator(hf_weights_files) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config, + cache_config) + model.load_weights( + self._get_weights_iterator(model_config.model, + model_config.revision, + fall_back_to_pt=getattr( + model, + "fall_back_to_pt_during_load", + True)), ) + if model_config.quantization == 'fp8' and model_config.quantization_param_path is not None: + model.load_quantized_weights( + safetensors_weights_iterator([model_config.model + model_config.quantization_param_path]) + ) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + return model.eval() + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config, + cache_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + return model.eval() + + +class TensorizerLoader(BaseModelLoader): + """Model loader using CoreWeave's tensorizer library.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if isinstance(load_config.model_loader_extra_config, TensorizerConfig): + self.tensorizer_config = load_config.model_loader_extra_config + else: + self.tensorizer_config = TensorizerConfig( + **load_config.model_loader_extra_config) + + def _verify_config(self, model_config: ModelConfig, + parallel_config: ParallelConfig): + self.tensorizer_config.verify_with_model_config(model_config) + self.tensorizer_config.verify_with_parallel_config(parallel_config) + + def _get_weights_iterator( + self) -> Generator[Tuple[str, torch.Tensor], None, None]: + tensorizer_args = self.tensorizer_config._construct_tensorizer_args() + return tensorizer_weights_iterator(tensorizer_args) + + def _load_model_serialized_cpu( + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer to the CPU. + + This is only necessary when the model isn't vLLM-tensorized (see + examples/tensorize_vllm_model.py) This should still be faster than + default HuggingFace loading, but will be slower than loading a + vLLM-tensorized model. + """ + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config, + cache_config) + + model.load_weights(self._get_weights_iterator()) + return model.eval() + + def _load_model_serialized( + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer. + + Expects a vLLM-tensorized model. See the + examples/tensorize_vllm_model.py example script + for serializing vLLM models.""" + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model_class = get_model_architecture(model_config)[0] + quant_config = _get_quantization_config( + model_config, self.load_config) + extra_kwargs = _get_model_initialization_kwargs( + model_class, lora_config, vision_language_config) + extra_kwargs["quant_config"] = quant_config + extra_kwargs["cache_config"] = cache_config + + tensorizer_config = copy.copy(self.tensorizer_config) + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + + model = load_with_tensorizer(tensorizer_config, **extra_kwargs) + return model.eval() + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + self._verify_config(model_config, parallel_config) + + if is_vllm_tensorized(self.tensorizer_config): + return self._load_model_serialized(model_config, device_config, + lora_config, + vision_language_config, + cache_config) + return self._load_model_serialized_cpu(model_config, device_config, + lora_config, + vision_language_config, + cache_config) + + +class ShardedStateLoader(BaseModelLoader): + """ + Model loader that directly loads each worker's model state dict, which + enables a fast load path for large tensor-parallel models where each worker + only needs to read its own shard rather than the entire checkpoint. See + `examples/save_sharded_states.py` for creating a sharded checkpoint. + """ + + DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + extra_config = ({} if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy()) + self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) + if extra_config: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}") + + @staticmethod + def _filter_subtensors( + tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. + """ + same_storage_groups = collections.defaultdict(list) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. + break + else: + result[k] = t + return result + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]): + if os.path.isdir(model_name_or_path): + return model_name_or_path + else: + allow_patterns = ["*.safetensors"] + return download_weights_from_hf(model_name_or_path, + self.load_config.download_dir, + allow_patterns, revision) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + from safetensors.torch import safe_open + + from vllm.distributed import get_tensor_model_parallel_rank + + local_model_path = self._prepare_weights(model_config.model, + model_config.revision) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config, + cache_config) + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + local_model_path, + self.pattern.format(rank=rank, part="*"), + ) + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!") + state_dict = self._filter_subtensors(model.state_dict()) + for path in filepaths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. + param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", tensor.shape, + key, param_shape) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") + return model.eval() + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from safetensors.torch import save_file + + from vllm.distributed import get_tensor_model_parallel_rank + if pattern is None: + pattern = ShardedStateLoader.DEFAULT_PATTERN + rank = get_tensor_model_parallel_rank() + part_idx = 0 + total_size = 0 + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: Dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): + param_size = tensor.nelement() * tensor.element_size() + if max_size is not None and total_size + param_size > max_size: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + part_idx += 1 + total_size = 0 + state_dict_part = {} + state_dict_part[key] = tensor + total_size += param_size + if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.TENSORIZER: + return TensorizerLoader(load_config) + + if load_config.load_format == LoadFormat.SHARDED_STATE: + return ShardedStateLoader(load_config) + + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/neuron_model_loader.py b/vllm/model_executor/model_loader/neuron.py similarity index 92% rename from vllm/model_executor/neuron_model_loader.py rename to vllm/model_executor/model_loader/neuron.py index 43d17ad373b87..07e23aca6cc5f 100644 --- a/vllm/model_executor/neuron_model_loader.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -1,7 +1,7 @@ """Utilities for selecting and loading neuron models.""" import importlib import os -from typing import Optional, Type +from typing import Dict, Optional, Tuple import torch import torch.nn as nn @@ -27,7 +27,7 @@ } # Models supported by Neuron. -_NEURON_SUPPORTED_MODELS = { +_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = { "LlamaForCausalLM": ("transformers_neuronx.llama.model", "LlamaForSampling", "LlamaForCausalLM"), "MistralForCausalLM": ("transformers_neuronx.mistral.model", @@ -43,11 +43,13 @@ def __init__( ) -> None: super().__init__() self.config = config - self.model = None self.logits_processor = LogitsProcessor(config.vocab_size, logits_as_input=True) self.sampler = Sampler() + # Lazy initialized + self.model: nn.Module + def forward( self, input_ids: torch.Tensor, @@ -74,17 +76,17 @@ def sample( def load_weights(self, model_name_or_path: str, **kwargs): arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls, hf_model_cls = ( + neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = ( _NEURON_SUPPORTED_MODELS[arch]) neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) split_model_dir = f"{model_name_or_path}-split" if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")): split_model_dir = model_name_or_path elif not os.path.exists(f"{model_name_or_path}-split"): - hf_model_cls = getattr(transformers, hf_model_cls) + hf_model_cls = getattr(transformers, hf_model_cls_name) from transformers_neuronx.module import save_pretrained_split hf_model = hf_model_cls.from_pretrained(model_name_or_path, @@ -96,7 +98,7 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.to_neuron() -def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: +def _get_model_architecture(config: PretrainedConfig) -> str: architectures = getattr(config, "architectures", []) for arch in architectures: if arch in _NEURON_SUPPORTED_MODELS: diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py new file mode 100644 index 0000000000000..2cf4ce5f88521 --- /dev/null +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -0,0 +1,432 @@ +import argparse +import dataclasses +import io +import os +import time +import typing +from dataclasses import dataclass +from functools import partial +from typing import Generator, Optional, Tuple, Type, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.config import ModelConfig, ParallelConfig +from vllm.engine.llm_engine import LLMEngine +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + +tensorizer_error_msg = None + +try: + from tensorizer import (DecryptionParams, EncryptionParams, + TensorDeserializer, TensorSerializer) + from tensorizer.stream_io import open_stream + from tensorizer.utils import (convert_bytes, get_mem_usage, + no_init_or_tensor) + + _read_stream, _write_stream = (partial( + open_stream, + mode=mode, + ) for mode in ("rb", "wb+")) +except ImportError as e: + tensorizer_error_msg = str(e) + +__all__ = [ + 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', + 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', + 'no_init_or_tensor', 'TensorizerConfig' +] + +logger = init_logger(__name__) + + +@dataclass +class TensorizerConfig: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: Optional[bool] = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = None + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + model_class: Optional[Type[torch.nn.Module]] = None + hf_config: Optional[PretrainedConfig] = None + dtype: Optional[Union[str, torch.dtype]] = None + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + tensorizer_args = { + "tensorizer_uri": self.tensorizer_uri, + "vllm_tensorized": self.vllm_tensorized, + "verify_hash": self.verify_hash, + "num_readers": self.num_readers, + "encryption_keyfile": self.encryption_keyfile, + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + return TensorizerArgs(**tensorizer_args) # type: ignore + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if (parallel_config.tensor_parallel_size > 1 + and self.tensorizer_uri is not None): + raise ValueError( + "Loading to multiple GPUs is not currently supported with " + "vLLM-serialized models. Please set tensor_parallel_size=1." + " or use a non-vLLM-serialized model, such as a " + "serialized Hugging Face `PretrainedModel`.") + + def verify_with_model_config(self, model_config: "ModelConfig") -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + logger.warning( + "Loading a model using Tensorizer with quantization on vLLM" + " is unstable and may lead to errors.") + + +def load_with_tensorizer(tensorizer_config: TensorizerConfig, + **extra_kwargs) -> nn.Module: + tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) + return tensorizer.deserialize() + + +@dataclass +class TensorizerArgs: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: Optional[bool] = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = None + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + """ + Args for the TensorizerAgent class. These are used to configure the behavior + of the TensorDeserializer when loading tensors from a serialized model. + + Args: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. + verify_hash: If True, the hashes of each tensor will be verified against + the hashes stored in the metadata. A `HashMismatchError` will be + raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + """ + + def __post_init__(self): + self.file_obj = self.tensorizer_uri + self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID + self.s3_secret_access_key = (self.s3_secret_access_key + or envs.S3_SECRET_ACCESS_KEY) + self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL + self.stream_params = { + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + + self.deserializer_params = { + "verify_hash": self.verify_hash, + "encryption": self.encryption_keyfile, + "num_readers": self.num_readers + } + + if self.encryption_keyfile: + with open_stream( + self.encryption_keyfile, + **self.stream_params, + ) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + self.deserializer_params['encryption'] = decryption_params + + @staticmethod + def add_cli_args( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Tensorizer CLI arguments""" + + # Tensorizer options arg group + group = parser.add_argument_group( + 'tensorizer options', + description=('Options for configuring the behavior of the' + ' tensorizer deserializer when ' + 'load_format=tensorizer is specified when ' + 'initializing an LLMEngine, either via the CLI ' + 'when running the vLLM OpenAI inference server ' + 'with a JSON string passed to ' + '--model-loader-extra-config or as arguments given ' + 'to TensorizerConfig when passed to ' + 'model_loader_extra_config in the constructor ' + 'for LLMEngine.')) + + group.add_argument( + "--tensorizer-uri", + help="Path to serialized model tensors. Can be a local file path," + " or an HTTP(S) or S3 URI.", + ) + group.add_argument( + "--verify-hash", + action="store_true", + help="If enabled, the hashes of each tensor will be verified" + " against the hashes stored in the file metadata. An exception" + " will be raised if any of the hashes do not match.", + ) + group.add_argument( + "--encryption-keyfile", + default=None, + help="The file path to a binary file containing a binary key to " + "use for decryption. Can be a file path or S3 network URI.") + group.add_argument( + "--num-readers", + default=None, + type=int, + help="Controls how many threads are allowed to read concurrently " + "from the source file. Default is `None`, which will dynamically " + "set the number of readers based on the available resources " + "and model size. This greatly increases performance.") + group.add_argument( + "--s3-access-key-id", + default=None, + help="The access key for the S3 bucket. Can also be set via the " + "S3_ACCESS_KEY_ID environment variable.", + ) + group.add_argument( + "--s3-secret-access-key", + default=None, + help="The secret access key for the S3 bucket. Can also be set via " + "the S3_SECRET_ACCESS_KEY environment variable.", + ) + group.add_argument( + "--s3-endpoint", + default=None, + help="The endpoint for the S3 bucket. Can also be set via the " + "S3_ENDPOINT_URL environment variable.", + ) + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": + attrs = [attr.name for attr in dataclasses.fields(cls)] + tensorizer_args = cls(**{ + attr: getattr(args, attr) + for attr in attrs if hasattr(args, attr) + }) + return tensorizer_args + + +class TensorizerAgent: + """ + A class for performing tensorizer deserializations specifically for + vLLM models using plaid_mode. Uses TensorizerArgs to configure the + behavior of the TensorDeserializer when loading tensors from a serialized + model. For deserializations of HuggingFace models, TensorDeserializer is + instead used as an iterator directly in the func hf_model_weights_iterator + in vllm/model_executor/model_loader/weight_utils.py + """ + + def __init__(self, tensorizer_config: TensorizerConfig, + quant_config: QuantizationConfig, **extra_kwargs): + if tensorizer_error_msg is not None: + raise ImportError( + "Tensorizer is not installed. Please install tensorizer " + "to use this feature with `pip install vllm[tensorizer]`. " + "Error message: {}".format(tensorizer_error_msg)) + + self.tensorizer_config = tensorizer_config + self.tensorizer_args = ( + self.tensorizer_config._construct_tensorizer_args()) + self.extra_kwargs = extra_kwargs + if extra_kwargs.get("quant_config", None) is not None: + self.quant_config = extra_kwargs["quant_config"] + else: + self.quant_config = quant_config + self.model = self._init_model() + + def _init_model(self): + assert self.tensorizer_config.hf_config is not None + model_args = self.tensorizer_config.hf_config + model_args.torch_dtype = self.tensorizer_config.dtype + assert self.tensorizer_config.model_class is not None + with no_init_or_tensor(): + return self.tensorizer_config.model_class( + config=model_args, + quant_config=self.quant_config, + **self.extra_kwargs) + + def _resize_lora_embeddings(self): + """Modify LoRA embedding layers to use bigger tensors + to allow for adapter added tokens.""" + for child in self.model.modules(): + if (isinstance(child, VocabParallelEmbedding) + and child.weight.shape[0] < + child.num_embeddings_per_partition): + new_weight = torch.empty(child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device) + new_weight[:child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0]:].fill_(0) + child.weight.data = new_weight + + def _check_tensors_on_meta_device(self): + for tensor in self.model.state_dict().values(): + if tensor.device.type == 'meta': + raise ValueError( + "The serialized model contains tensors on the meta device," + " indicating that some tensors were not loaded properly." + " Please check that the parameters of the model being" + " specified match that of the serialized model, such as" + " its quantization.") + + def deserialize(self): + """ + Deserialize the model using the TensorDeserializer. This method is + specifically for vLLM models using tensorizer's plaid_mode. + + The deserializer makes use of tensorizer_args.stream_params + to configure the behavior of the stream when loading tensors from a + serialized model. The deserializer_params are used to configure the + behavior of the TensorDeserializer when loading tensors themselves. + Documentation on these params can be found in TensorizerArgs + + Returns: + nn.Module: The deserialized model. + """ + before_mem = get_mem_usage() + start = time.perf_counter() + with _read_stream( + self.tensorizer_config.tensorizer_uri, + **self.tensorizer_args.stream_params + ) as stream, TensorDeserializer( + stream, + dtype=self.tensorizer_config.dtype, + **self.tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(self.model) + end = time.perf_counter() + + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + deserializer.close() + logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, + end - start, per_second) + logger.info("Memory usage before: %s", before_mem) + logger.info("Memory usage after: %s", after_mem) + + self._check_tensors_on_meta_device() + self._resize_lora_embeddings() + del self.model.vllm_tensorized_marker + return self.model.eval() + + +def tensorizer_weights_iterator( + tensorizer_args: "TensorizerArgs" +) -> Generator[Tuple[str, torch.Tensor], None, None]: + logger.warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the examples/tensorize_vllm_model.py example " + "script for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + for name, param in state.items(): + yield name, param + del state + + +def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: + """ + Infer if the model is a vLLM model by checking the weights for + a vLLM tensorized marker. + + Args: + tensorizer_config: The TensorizerConfig object containing the + tensorizer_uri to the serialized model. + + Returns: + bool: True if the model is a vLLM model, False otherwise. + """ + tensorizer_args = tensorizer_config._construct_tensorizer_args() + deserializer = TensorDeserializer(open_stream( + tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), + **tensorizer_args.deserializer_params, + lazy_load=True) + if tensorizer_config.vllm_tensorized: + logger.warning( + "Please note that newly serialized vLLM models are automatically " + "inferred as vLLM models, so setting vllm_tensorized=True is " + "only necessary for models serialized prior to this change.") + return True + if (".vllm_tensorized_marker" in deserializer): + return True + return False + + +def get_pretensorized_vllm_model(engine: "LLMEngine") -> nn.Module: + model = (engine.model_executor.driver_worker.model_runner.model) + model.register_parameter( + "vllm_tensorized_marker", + nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + return model + + +def serialize_vllm_model(engine: "LLMEngine", + tensorizer_config : TensorizerConfig, + encryption_key_path: Optional[str] = None) \ + -> nn.Module: + + model = get_pretensorized_vllm_model(engine) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + encryption_params = None + if encryption_key_path is not None: + encryption_params = EncryptionParams.random() + with _write_stream(encryption_key_path, + **tensorizer_args.stream_params) as stream: + stream.write(encryption_params.key) + + with _write_stream(tensorizer_args.tensorizer_uri, + **tensorizer_args.stream_params) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + logger.info("Successfully serialized model to %s", + str(tensorizer_args.tensorizer_uri)) + return model diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py new file mode 100644 index 0000000000000..f7e0f56c1a46e --- /dev/null +++ b/vllm/model_executor/model_loader/utils.py @@ -0,0 +1,41 @@ +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Tuple, Type + +import torch +from torch import nn + +from vllm.config import ModelConfig +from vllm.model_executor.models import ModelRegistry + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_model_architecture( + model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. + if (model_config.quantization is not None + and model_config.quantization != "fp8" + and "MixtralForCausalLM" in architectures): + architectures = ["QuantMixtralForCausalLM"] + + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py new file mode 100644 index 0000000000000..53e21eba8fae3 --- /dev/null +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -0,0 +1,448 @@ +"""Utilities for downloading and initializing model weights.""" +import fnmatch +import glob +import hashlib +import json +import os +import tempfile +from collections import defaultdict +from typing import Any, Generator, Iterable, List, Optional, Tuple + +import filelock +import huggingface_hub.constants +import numpy as np +import torch +from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from safetensors.torch import load_file, safe_open, save_file +from tqdm.auto import tqdm +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import LoadConfig, ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import (QuantizationConfig, + get_quantization_config) +from vllm.model_executor.layers.quantization.schema import QuantParamSchema + +logger = init_logger(__name__) + +# use system-level temp directory for file locks, so that multiple users +# can share the same lock without error. +# lock files in the temp directory will be automatically deleted when the +# system reboots, so users will not complain about annoying lock files +temp_dir = tempfile.gettempdir() + + +def enable_hf_transfer(): + """automatically activates hf_transfer + """ + if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass + + +enable_hf_transfer() + + +class DisabledTqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), + mode=0o666) + return lock + + +def _shared_pointers(tensors): + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + failing = [] + for _, names in ptrs.items(): + if len(names) > 1: + failing.append(names) + return failing + + +def convert_bin_to_safetensor_file( + pt_filename: str, + sf_filename: str, +) -> None: + loaded = torch.load(pt_filename, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + shared = _shared_pointers(loaded) + for shared_weights in shared: + for name in shared_weights[1:]: + loaded.pop(name) + + # For tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata={"format": "pt"}) + + # check file size + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError(f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """) + + # check if the tensors are the same + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +# TODO(woosuk): Move this to other place. +def get_quant_config(model_config: ModelConfig, + load_config: LoadConfig) -> QuantizationConfig: + quant_cls = get_quantization_config(model_config.quantization) + # Read the quantization config from the HF model config, if available. + hf_quant_config = getattr(model_config.hf_config, "quantization_config", + None) + if hf_quant_config is None: + compression_config = getattr(model_config.hf_config, + "compression_config", None) + if compression_config is not None: + hf_quant_config = compression_config.get("quantization_config", + None) + + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file, "r") as f: + config = json.load(f) + return quant_cls.from_config(config) + + +def download_weights_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, +) -> str: + """Download model weights from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (List[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + + Returns: + str: The path to the downloaded model weights. + """ + if not huggingface_hub.constants.HF_HUB_OFFLINE: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + logger.info("Using model weights format %s", allow_patterns) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + return hf_folder + + +def download_safetensors_index_file_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + revision: Optional[str] = None, +) -> None: + """Download hf safetensors index file from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + revision (Optional[str]): The revision of the model. + """ + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + try: + # Download the safetensors index file. + hf_hub_download( + repo_id=model_name_or_path, + filename=SAFE_WEIGHTS_INDEX_NAME, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + # If file not found on remote or locally, we should not fail since + # only some models will have SAFE_WEIGHTS_INDEX_NAME. + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME) + except huggingface_hub.utils.LocalEntryNotFoundError: + logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME) + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the SAFE_WEIGHTS_INDEX_NAME to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files(hf_weights_files: List[str], + hf_folder: str) -> List[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name) as index_file: + weight_map = json.load(index_file)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add( + os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. + hf_weights_files = [ + f for f in hf_weights_files if f in weight_files_in_index + ] + return hf_weights_files + + +def filter_files_not_needed_for_inference( + hf_weights_files: List[str]) -> List[str]: + """ + Exclude files that are not needed for inference. + + See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + """ + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files + if not any(f.endswith(x) for x in blacklist) + ] + return hf_weights_files + + +def np_cache_weights_iterator( + model_name_or_path: str, cache_dir: Optional[str], hf_folder: str, + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model np files. + + Will dump the model weights to numpy files if they are not already dumped. + """ + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. + np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. + with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names = [] + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file, "r") as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + + +def safetensors_weights_iterator( + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + for st_file in hf_weights_files: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def pt_weights_iterator( + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model bin/pt files.""" + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + yield name, param + del state + torch.cuda.empty_cache() + + +def kv_cache_scales_loader( + filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, + model_type: Optional[str]) -> Iterable[Tuple[int, float]]: + """ + A simple utility to read in KV cache scaling factors that have been + previously serialized to disk. Used by the model to populate the appropriate + KV cache scaling factors. The serialization should represent a dictionary + whose keys are the TP ranks and values are another dictionary mapping layers + to their KV cache scaling factors. + Keep this function in sync with the output of examples/fp8/extract_scales.py + """ + try: + with open(filename) as f: + context = { + "model_type": model_type, + "num_hidden_layers": num_hidden_layers, + "tp_rank": tp_rank, + "tp_size": tp_size, + } + schema_dct = json.load(f) + schema = QuantParamSchema.model_validate(schema_dct, + context=context) + layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] + return layer_scales_map.items() + + except FileNotFoundError: + logger.error("File or directory '%s' not found.", filename) + except json.JSONDecodeError: + logger.error("Error decoding JSON in file '%s'.", filename) + except Exception as e: + logger.error("An error occurred while reading '%s': %s", filename, e) + # This section is reached if and only if any of the excepts are hit + # Return an empty iterable (list) => no KV cache scales are loaded + # which ultimately defaults to 1.0 scales + logger.warning( + "Defaulting to KV cache scaling factors = 1.0 for all " + "layers in TP rank %d as an error occurred during loading.", tp_rank) + return [] + + +def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: + """convert PySafeSlice object from safetensors to torch.Tensor + + PySafeSlice object supports indexing, which is done before loading the + actual tensor and can reduce the amount of memory being read into the + memory. However, it does not support more advanced functionalities + like `.view()` or `.t()`. Therefore, if we need to modify the loaded + tensor with these more complicated operators, we need to convert to + tensor first. + """ + if not isinstance(x, torch.Tensor): + x = x[:] + return x + + +def default_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, +) -> None: + """Initialize model weights with random values. + + The model weights must be randomly initialized for accurate performance + measurements. Additionally, the model weights should not cause NaNs in the + forward pass. We empirically found that initializing the weights with + values between -1e-3 and 1e-3 works well for most models. + """ + for param in model.state_dict().values(): + if torch.is_floating_point(param): + if torch.finfo(param.data.dtype).bits < 16: + # uniform_ doesn't support < 16-bit datatypes (FP8) + dtype = param.data.dtype + tmp_param = param.data.to(torch.float16) + tmp_param = tmp_param.uniform_(low, high).to(dtype) + param.data.copy_(tmp_param) + else: + param.uniform_(low, high) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 17fc970568042..a92abe6b5b8dc 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -9,7 +9,7 @@ logger = init_logger(__name__) # Architecture -> (module, class). -_MODELS = { +_GENERATION_MODELS = { "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b @@ -42,10 +42,11 @@ "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), - "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"), + "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), + "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), @@ -53,9 +54,17 @@ "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), + "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), } +_EMBEDDING_MODELS = { + "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), +} + +_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} + # Architecture -> type. # out of tree models _OOT_MODELS: Dict[str, Type[nn.Module]] = {} @@ -90,8 +99,8 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: "ROCm for now.") if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: logger.warning( - f"Model architecture {model_arch} is partially supported " - "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) + "Model architecture %s is partially supported by ROCm: %s", + model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) module_name, model_cls_name = _MODELS[model_arch] module = importlib.import_module( @@ -106,12 +115,16 @@ def get_supported_archs() -> List[str]: def register_model(model_arch: str, model_cls: Type[nn.Module]): if model_arch in _MODELS: logger.warning( - f"Model architecture {model_arch} is already registered, " - "and will be overwritten by the new model " - f"class {model_cls.__name__}.") + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls.__name__) global _OOT_MODELS _OOT_MODELS[model_arch] = model_cls + @staticmethod + def is_embedding_model(model_arch: str) -> bool: + return model_arch in _EMBEDDING_MODELS + __all__ = [ "ModelRegistry", diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py new file mode 100644 index 0000000000000..313762b1353d1 --- /dev/null +++ b/vllm/model_executor/models/arctic.py @@ -0,0 +1,532 @@ +"""Inference-only Snowflake Arctic model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig, DeepSpeedFPParameter) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs.arctic import ArcticConfig + +logger = init_logger(__name__) + + +class ArcticMLP(nn.Module): + + def __init__(self, + config: ArcticConfig, + layer_id: int, + expert_id: int = -1, + is_residual_mlp: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMLP, self).__init__() + self.hidden_size = config.hidden_size + self.expert_id = expert_id + self.layer_id = layer_id + + self.ffn_dim = config.intermediate_size if not is_residual_mlp \ + else self.hidden_size + + self.w13 = MergedColumnParallelLinear(self.hidden_size, + [self.ffn_dim] * 2, + bias=False, + quant_config=quant_config) + self.w2 = RowParallelLinear(self.ffn_dim, + self.hidden_size, + bias=False, + reduce_results=reduce_results, + quant_config=quant_config) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states): + gate_up, _ = self.w13(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.w2(hidden_states) + return hidden_states + + +class ArcticMoE(nn.Module): + """ + Model-parallel implementation of Arctic MoE Layer. + """ + + def __init__(self, + config: ArcticConfig, + layer_id: int, + tp_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMoE, self).__init__() + + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.hidden_size = config.hidden_size + self.num_experts = config.num_local_experts + self.layer_id = layer_id + self.top_k = config.num_experts_per_tok + self.intermediate_size = config.intermediate_size // self.tp_size + + self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 + self.is_quant = isinstance(quant_config, DeepSpeedFPConfig) + self.reduce_results = reduce_results + # Some other parameters + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + if not self.is_moe_layer: + self.mlp = ArcticMLP(config, + layer_id=layer_id, + quant_config=quant_config, + reduce_results=reduce_results) + else: + self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=quant_config) + if self.is_quant: + self.ws = DeepSpeedFPParameter( + torch.Size((self.num_experts, 2 * self.intermediate_size, + self.hidden_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + self.w2s = DeepSpeedFPParameter( + torch.Size((self.num_experts, self.hidden_size, + self.intermediate_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + else: + self.ws = nn.Parameter( + torch.empty(self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.ds_dequantize() if self.is_quant else param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + if self.is_quant: + param.ds_quantize_(param_data) + + def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + do_normalize = self.top_k > 1 + topk_weights, topk_ids = fused_topk(hidden_states, + router_logits, + self.top_k, + renormalize=do_normalize) + # topk_ids: (num_tokens, k) + if self.is_quant: + if 2 * num_tokens <= self.num_experts: + # If much fewer tokens than experts, use selective dequantize. + ws_dequantized = self.ws.ds_selective_dequantize( + topk_ids.flatten()) + w2s_dequantized = self.w2s.ds_selective_dequantize( + topk_ids.flatten()) + # We gathered the experts to the tokens so update the mapping. + topk_ids = torch.arange( + 0, + topk_ids.numel(), + device=topk_ids.device, + ).reshape(topk_ids.shape) + else: + ws_dequantized = self.ws.ds_dequantize() + w2s_dequantized = self.w2s.ds_dequantize() + + final_hidden_states = fused_experts( + hidden_states, + ws_dequantized if self.is_quant else self.ws, + w2s_dequantized if self.is_quant else self.w2s, + topk_weights, + topk_ids, + inplace=True) + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + def forward(self, hidden_states: torch.Tensor): + if self.is_moe_layer: + final_hidden_states = self.local_moe_fused(hidden_states) + else: + final_hidden_states = self.mlp(hidden_states) + return final_hidden_states + + +class ArcticAttention(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + reduce_results=True, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class ArcticDecoderLayer(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 + self.use_residual = config.use_residual and is_moe_layer + self.self_attn = ArcticAttention(config, + layer_idx, + cache_config, + quant_config=quant_config) + self.block_sparse_moe = ArcticMoE( + config, + layer_id=layer_idx, + quant_config=quant_config, + reduce_results=(not self.use_residual)) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + if self.use_residual: + self.residual_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.residual_mlp = ArcticMLP(config, + layer_id=layer_idx, + is_residual_mlp=True, + reduce_results=False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual_input = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual_input + hidden_states + + residual_attn = hidden_states + if self.use_residual: + hidden_states = self.residual_layernorm(hidden_states) + hidden_states = self.residual_mlp(hidden_states) + residual_mlp = hidden_states + hidden_states = self.post_attention_layernorm(residual_input) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_mlp + hidden_states + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states = residual_attn + hidden_states + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_attn + hidden_states + return hidden_states + + +class ArcticModel(nn.Module): + + def __init__( + self, + config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=self.vocab_size) + self.layers = nn.ModuleList([ + ArcticDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self._attn_implementation = config._attn_implementation + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, kv_caches[i], + attn_metadata) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class ArcticForCausalLM(nn.Module): + + def __init__(self, + config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + **kwargs) -> None: + super().__init__() + self.config = config + self.model = ArcticModel(config, cache_config, quant_config) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + ) + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.unpadded_vocab_size = config.vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + mlp_params_mapping = [] + expert_params_mapping = [] + num_layers = self.config.num_hidden_layers + + for layer in range(num_layers): + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w3.weight", 1)) + if layer % 2 == 0: + # MLP layers + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1)) + else: + # MoE layers + for expert_id in range(self.config.num_local_experts): + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w1.weight", expert_id)) + expert_params_mapping.append( + ("w2s", f"experts.{expert_id}.w2.weight", expert_id)) + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w3.weight", expert_id)) + + params_dict = dict(self.named_parameters()) + + logger.info( + "It will take ~10 minutes loading from the 16-bit weights. " + "Alternatively, use the prequantized 8-bit weights of arctic " + "and set load-format to `sharded_state` will accelerate loading.") + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id in mlp_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id \ + in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index fa5a27b5a6974..babb92e7cdcef 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -19,30 +19,30 @@ # limitations under the License. """Inference-only BaiChuan model compatible with HuggingFace weights.""" import math -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -78,17 +78,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -111,7 +111,8 @@ def __init__( position_embedding: str, rope_theta: float = 10000, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = hidden_size @@ -133,13 +134,13 @@ def __init__( self.total_num_heads, self.total_num_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) # Create the alibi slopes and slice them. if self.postion_embedding == "ALIBI": @@ -153,7 +154,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scaling, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + quant_config=quant_config) else: self.rotary_emb = get_rope( self.head_dim, @@ -162,7 +164,11 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -185,7 +191,8 @@ class BaiChuanDecoderLayer(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -197,13 +204,14 @@ def __init__(self, position_embedding=position_embedding, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + cache_config=cache_config, + quant_config=quant_config, ) self.mlp = BaiChuanMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -244,7 +252,8 @@ class BaiChuanModel(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -255,7 +264,8 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, linear_method) + BaiChuanDecoderLayer(config, position_embedding, cache_config, + quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -304,13 +314,15 @@ def __init__( self, config, position_embedding: str, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.model = BaiChuanModel(config, position_embedding, linear_method) + self.quant_config = quant_config + self.model = BaiChuanModel(config, position_embedding, cache_config, + quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -340,19 +352,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if name == "lm_head.weight": @@ -394,13 +401,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", linear_method, lora_config) + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", linear_method, lora_config) + super().__init__(config, "ALIBI", cache_config, quant_config, + lora_config) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -409,7 +419,9 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, "ROPE", linear_method, lora_config) + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index a9ff909090586..a29aee4cffb7d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -17,27 +17,28 @@ # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -71,7 +72,8 @@ class BloomAttention(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -88,13 +90,13 @@ def __init__( self.head_dim, self.total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) # Create the alibi slopes and slice them. @@ -108,7 +110,9 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scaling, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -130,21 +134,20 @@ class BloomMLP(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.dense_h_to_4h = ColumnParallelLinear( hidden_size, 4 * hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -159,17 +162,19 @@ class BloomBlock(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, linear_method) + self.self_attention = BloomAttention(config, cache_config, + quant_config) self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) - self.mlp = BloomMLP(config, linear_method) + self.mlp = BloomMLP(config, quant_config) self.apply_residual_connection_post_layernorm = ( config.apply_residual_connection_post_layernorm) @@ -215,7 +220,8 @@ class BloomModel(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -230,7 +236,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - BloomBlock(config, linear_method) + BloomBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -263,12 +269,13 @@ class BloomForCausalLM(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = BloomModel(config, linear_method) + self.quant_config = quant_config + self.transformer = BloomModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -298,14 +305,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if name == "lm_head.weight": continue if not name.startswith("transformer."): diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 4008896e48dd1..e3a5e43e23e1c 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -2,30 +2,29 @@ # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from torch.nn import LayerNorm from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig @@ -35,7 +34,8 @@ class GLMAttention(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -67,13 +67,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=config.add_bias_linear or config.add_qkv_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=config.add_bias_linear, - linear_method=linear_method, + quant_config=quant_config, ) # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 @@ -86,12 +86,12 @@ def __init__( base=10000 * rope_ratio, is_neox_style=False, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -125,7 +125,7 @@ class GLMMLP(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -136,7 +136,7 @@ def __init__( config.hidden_size, [config.ffn_hidden_size] * 2, bias=config.add_bias_linear, - linear_method=linear_method, + quant_config=quant_config, ) self.activation_func = SiluAndMul() @@ -146,7 +146,7 @@ def __init__( config.ffn_hidden_size, config.hidden_size, bias=config.add_bias_linear, - linear_method=linear_method, + quant_config=quant_config, ) def forward(self, hidden_states): @@ -168,7 +168,8 @@ class GLMBlock(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.apply_residual_connection_post_layernorm = ( @@ -182,7 +183,7 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. - self.self_attention = GLMAttention(config, linear_method) + self.self_attention = GLMAttention(config, cache_config, quant_config) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -190,7 +191,7 @@ def __init__( config.hidden_size, eps=config.layernorm_epsilon) # MLP - self.mlp = GLMMLP(config, linear_method) + self.mlp = GLMMLP(config, quant_config) def forward( self, @@ -238,7 +239,8 @@ class GLMTransformer(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.post_layer_norm = config.post_layer_norm @@ -247,8 +249,10 @@ def __init__( self.num_layers = config.num_layers # Transformer layers. - self.layers = nn.ModuleList( - [GLMBlock(config, linear_method) for i in range(self.num_layers)]) + self.layers = nn.ModuleList([ + GLMBlock(config, cache_config, quant_config) + for i in range(self.num_layers) + ]) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -283,7 +287,8 @@ class ChatGLMModel(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -293,7 +298,7 @@ def __init__( self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, linear_method) + self.encoder = GLMTransformer(config, cache_config, quant_config) self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) @@ -335,13 +340,16 @@ class ChatGLMForCausalLM(nn.Module): def __init__( self, config: ChatGLMConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config: ChatGLMConfig = config - self.linear_method = linear_method - self.transformer = ChatGLMModel(config, linear_method) + self.quant_config = quant_config + self.max_position_embeddings = getattr(config, "max_sequence_length", + 8192) + self.transformer = ChatGLMModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.output_layer.weight self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() @@ -371,14 +379,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_pos_emb.inv_freq" in name: continue if "word_embeddings" in name: diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 29ba3844eb11d..84786921ce1b4 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -20,7 +20,7 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch import torch.utils.checkpoint @@ -29,25 +29,38 @@ from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput +@torch.compile +def layer_norm_func(hidden_states, weight, variance_epsilon): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + + variance_epsilon) + hidden_states = weight.to(torch.float32) * hidden_states + return hidden_states.to(input_dtype) + + class LayerNorm(nn.Module): def __init__(self, param_shape=None, eps=1e-5): @@ -57,14 +70,9 @@ def __init__(self, param_shape=None, eps=1e-5): set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) def forward(self, hidden_states, residuals=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - mean = hidden_states.mean(-1, keepdim=True) - variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) - hidden_states = (hidden_states - - mean) * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = self.weight.to(torch.float32) * hidden_states - return hidden_states.to(input_dtype), residuals + hidden_states = layer_norm_func(hidden_states, self.weight, + self.variance_epsilon) + return hidden_states, residuals def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() @@ -85,7 +93,7 @@ class CohereMLP(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -95,13 +103,13 @@ def __init__( self.hidden_size, [self.intermediate_size] * 2, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.down_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.act_fn = SiluAndMul() @@ -117,7 +125,8 @@ class CohereAttention(nn.Module): def __init__( self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() tp_size = get_tensor_model_parallel_world_size() @@ -152,13 +161,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -168,12 +177,12 @@ def __init__( rope_scaling=self.rope_scaling, is_neox_style=False, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) if self.use_qk_norm: self.q_norm = LayerNorm(param_shape=(self.num_heads, self.head_dim), @@ -212,13 +221,16 @@ class CohereDecoderLayer(nn.Module): def __init__(self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None): + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CohereAttention(config, linear_method=linear_method) + self.self_attn = CohereAttention(config, + cache_config, + quant_config=quant_config) - self.mlp = CohereMLP(config, linear_method=linear_method) + self.mlp = CohereMLP(config, quant_config=quant_config) self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) @@ -251,7 +263,8 @@ class CohereModel(nn.Module): def __init__( self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -259,7 +272,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - CohereDecoderLayer(config, linear_method=linear_method) + CohereDecoderLayer(config, cache_config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = LayerNorm(param_shape=(config.hidden_size), @@ -292,14 +305,15 @@ class CohereForCausalLM(nn.Module): def __init__( self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.logits_processor = LogitsProcessor(config.vocab_size, scale=config.logit_scale) - self.model = CohereModel(config, linear_method) + self.model = CohereModel(config, cache_config, quant_config) self.sampler = Sampler() @torch.no_grad() @@ -328,13 +342,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -345,8 +353,7 @@ def load_weights( ] params_dict = dict(self.named_parameters()) loaded_params = set() - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 14c0fece69214..8ff19a2015e0f 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -1,28 +1,28 @@ # coding=utf-8 -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -46,7 +46,7 @@ def __init__( self.num_total_experts, bias=False, params_dtype=params_dtype, - linear_method=None, + quant_config=None, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -65,7 +65,7 @@ class DbrxExperts(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, ): super().__init__() @@ -167,7 +167,8 @@ class DbrxAttention(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model @@ -185,13 +186,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.out_proj = RowParallelLinear( self.d_model, self.d_model, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -217,12 +218,12 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -246,11 +247,11 @@ class DbrxFusedNormAttention(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model - self.attn = DbrxAttention(config, linear_method) + self.attn = DbrxAttention(config, quant_config) self.norm_1 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model) @@ -280,11 +281,13 @@ class DbrxBlock(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method) - self.ffn = DbrxExperts(config, linear_method) + self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, + quant_config) + self.ffn = DbrxExperts(config, quant_config) def forward( self, @@ -309,15 +312,18 @@ class DbrxModel(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.wte = VocabParallelEmbedding( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList( - [DbrxBlock(config, linear_method) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([ + DbrxBlock(config, cache_config, quant_config) + for _ in range(config.n_layers) + ]) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, @@ -350,13 +356,14 @@ class DbrxForCausalLM(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(config, linear_method) + self.transformer = DbrxModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, @@ -392,20 +399,13 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [( "ws" if weight_name in ["w1", "v1"] else "w2s", f"experts.mlp.{weight_name}", ) for weight_name in ["w1", "v1", "w2"]] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: for param_name, weight_name in expert_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index abf4a462871b0..e293ee491908d 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -23,16 +23,16 @@ # limitations under the License. """Inference-only DeciLM model compatible with HuggingFace weights.""" -from typing import Optional +from typing import Iterable, Optional, Tuple import torch from transformers import PretrainedConfig -from vllm.config import LoRAConfig -from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.config import CacheConfig, LoRAConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) class DeciLMForCausalLM(LlamaForCausalLM): @@ -56,20 +56,18 @@ class DeciLMForCausalLM(LlamaForCausalLM): def __init__( self, config: Optional[PretrainedConfig] = None, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: config.num_key_value_heads = max(config.num_key_value_heads_per_layer) delattr(config, "num_key_value_heads_per_layer") super().__init__(config=config, - linear_method=linear_method, + cache_config=cache_config, + quant_config=quant_config, lora_config=lora_config) - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -79,8 +77,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 2a2182ff4ebad..8fbda2638aaa3 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -21,33 +21,33 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -58,18 +58,18 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, reduce_results=reduce_results) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -88,7 +88,7 @@ class DeepseekMoE(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -105,7 +105,7 @@ def __init__( DeepseekMLP(hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False) for idx in range(self.n_routed_experts) ]) @@ -114,7 +114,7 @@ def __init__( self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, bias=False, - linear_method=None) + quant_config=None) if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * @@ -123,7 +123,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False, ) @@ -179,7 +179,8 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -210,14 +211,14 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -230,7 +231,9 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -253,7 +256,8 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -268,18 +272,19 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + cache_config=cache_config, + quant_config=quant_config, ) if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, linear_method=linear_method) + self.mlp = DeepseekMoE(config=config, quant_config=quant_config) else: self.mlp = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -317,10 +322,13 @@ def forward( class DeepseekModel(nn.Module): + fall_back_to_pt_during_load = False + def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -333,7 +341,8 @@ def __init__( self.layers = nn.ModuleList([ DeepseekDecoderLayer(config, layer_idx, - linear_method=linear_method) + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -361,12 +370,13 @@ class DeepseekForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = DeepseekModel(config, linear_method) + self.quant_config = quant_config + self.model = DeepseekModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -396,11 +406,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -411,12 +417,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 77c19b227d213..9618652f70d23 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -19,7 +19,7 @@ """PyTorch Falcon model.""" import math -from typing import List, Optional, Union +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -27,23 +27,23 @@ from transformers import FalconConfig as HF_FalconConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import RWConfig @@ -78,7 +78,8 @@ class FalconAttention(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -117,7 +118,7 @@ def __init__( self.total_num_kv_heads, bias=config.bias, skip_bias_add=True, - linear_method=linear_method, + quant_config=quant_config, ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -131,7 +132,7 @@ def __init__( self.hidden_size, bias=config.bias, skip_bias_add=True, - linear_method=linear_method, + quant_config=quant_config, reduce_results=self.reduce_row_parallel_results) self.use_rotary = config.rotary @@ -152,7 +153,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + quant_config=quant_config) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -164,12 +166,15 @@ def __init__( self.head_dim, self.inv_norm_factor, num_kv_heads=self.num_kv_heads, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + quant_config=quant_config) else: self.attn = Attention(self.num_heads, self.head_dim, scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -194,7 +199,7 @@ class FalconMLP(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -203,8 +208,7 @@ def __init__( 4 * hidden_size, bias=config.bias, skip_bias_add=True, - linear_method=linear_method) - quant_config = getattr(linear_method, "quant_config", None) + quant_config=quant_config) self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) @@ -214,7 +218,7 @@ def __init__( bias=config.bias, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results, - linear_method=linear_method) + quant_config=quant_config) def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. @@ -231,27 +235,37 @@ class FalconDecoderLayer(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config, linear_method) - self.mlp = FalconMLP(config, linear_method) + self.self_attention = FalconAttention(config, cache_config, + quant_config) + self.mlp = FalconMLP(config, quant_config) self.config = config - if config.new_decoder_architecture: - # The layer norm before self-attention - self.ln_attn = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) - # The layer norm before the MLP - self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - else: + if (config.num_ln_in_parallel_attn is None + and config.new_decoder_architecture): + config.num_ln_in_parallel_attn = 2 + + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - if not config.parallel_attn: - self.post_attention_layernorm = LayerNorm( - hidden_size, eps=config.layer_norm_epsilon) + else: + if config.num_ln_in_parallel_attn == 2: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) @@ -265,7 +279,7 @@ def forward( ) -> torch.Tensor: residual = hidden_states - if self.config.new_decoder_architecture: + if self.config.num_ln_in_parallel_attn == 2: attention_layernorm_out = self.ln_attn(hidden_states) mlp_layernorm_out = self.ln_mlp(hidden_states) else: @@ -288,6 +302,10 @@ def forward( residual += attention_output mlp_layernorm_out = self.post_attention_layernorm(residual) + if (self.config.new_decoder_architecture and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1): + mlp_layernorm_out = attention_layernorm_out + # MLP. mlp_output, mlp_bias = self.mlp(mlp_layernorm_out) if self.reduce_row_parallel_results and mlp_bias is not None: @@ -313,7 +331,8 @@ class FalconModel(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -329,7 +348,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - FalconDecoderLayer(config, linear_method) + FalconDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -361,13 +380,27 @@ class FalconForCausalLM(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = FalconModel(config, linear_method) - self.lm_head_weight = self.transformer.word_embeddings.weight + self.quant_config = quant_config + self.transformer = FalconModel(config, cache_config, quant_config) + # only Falcon-11B doesn't share lm_head weight with word embeddings + # and previous Falcon model doesn't have tie_word_embeddings config + # so we set tie_word_embeddings to True by default + self.tie_word_embeddings = (config.tie_word_embeddings + if config.tie_word_embeddings is not None + else True) + if self.tie_word_embeddings: + self.lm_head_weight = self.transformer.word_embeddings.weight + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + ) + self.lm_head_weight = self.lm_head.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -400,11 +433,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): total_num_heads = self.config.num_attention_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads @@ -414,10 +443,9 @@ def load_weights(self, total_num_kv_heads = total_num_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): - if name == "lm_head.weight": - # Falcon uses tied embeddings. + for name, loaded_weight in weights: + if name == "lm_head.weight" and self.tie_word_embeddings: + # Falcon uses tied embeddings except Falcon-11b. continue # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 08609532b8b3e..27dda00b66af4 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -15,31 +15,30 @@ # limitations under the License. """Inference-only Gemma model compatible with HuggingFace weights.""" from functools import lru_cache -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import GemmaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput logger = init_logger(__name__) @@ -57,10 +56,10 @@ def _get_gemma_act_fn( "in the config JSON file when it was initially released. " "Changing the activation function to approximate GeLU " "(`gelu_pytorch_tanh`). If you want to use the legacy " - f"`{hidden_act}`, edit the config JSON to set " - f"`hidden_activation={hidden_act}` instead of `hidden_act`. " + "`%s`, edit the config JSON to set " + "`hidden_activation=%s` instead of `hidden_act`. " "See https://github.com/huggingface/transformers/pull/29402 " - "for more details.") + "for more details.", hidden_act, hidden_act) return GeluAndMul(approximate="tanh") elif hidden_activation == "gelu_pytorch_tanh": return GeluAndMul(approximate="tanh") @@ -79,17 +78,17 @@ def __init__( intermediate_size: int, hidden_act: Optional[str] = None, hidden_activation: Optional[str] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation) def forward(self, x): @@ -108,7 +107,8 @@ def __init__(self, head_dim: int, max_position_embeddings: int = 8192, rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None) -> None: + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -137,13 +137,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -156,7 +156,9 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -178,7 +180,8 @@ class GemmaDecoderLayer(nn.Module): def __init__( self, config: GemmaConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -189,14 +192,15 @@ def __init__( head_dim=config.head_dim, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, - linear_method=linear_method, + cache_config=cache_config, + quant_config=quant_config, ) self.mlp = GemmaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, hidden_activation=getattr(config, "hidden_activation", None), - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -237,7 +241,8 @@ class GemmaModel(nn.Module): def __init__( self, config: GemmaConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -247,7 +252,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, linear_method) + GemmaDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -310,14 +315,15 @@ class GemmaForCausalLM(nn.Module): def __init__( self, config: GemmaConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: del lora_config # Unused. super().__init__() self.config = config - self.linear_method = linear_method - self.model = GemmaModel(config, linear_method) + self.quant_config = quant_config + self.model = GemmaModel(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -347,11 +353,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -362,8 +364,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) loaded_params = set() - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: for (param_name, shard_name, shard_id) in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3f816a9996be5..cc83f6eb6d94d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -17,27 +17,27 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -46,7 +46,8 @@ class GPT2Attention(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -63,15 +64,19 @@ def __init__( self.head_dim, total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) - self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -92,7 +97,7 @@ def __init__( self, intermediate_size: int, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -100,15 +105,14 @@ def __init__( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) @@ -124,7 +128,8 @@ class GPT2Block(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -132,9 +137,9 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, linear_method) + self.attn = GPT2Attention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPT2MLP(inner_dim, config, linear_method) + self.mlp = GPT2MLP(inner_dim, config, quant_config) def forward( self, @@ -165,7 +170,8 @@ class GPT2Model(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -176,7 +182,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPT2Block(config, linear_method) + GPT2Block(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -205,12 +211,13 @@ class GPT2LMHeadModel(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = GPT2Model(config, linear_method) + self.quant_config = quant_config + self.transformer = GPT2Model(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -240,14 +247,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 07c647c2e1c41..69b75763e9a3d 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -18,27 +18,27 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -47,7 +47,8 @@ class GPTBigCodeAttention(nn.Module): def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -74,19 +75,21 @@ def __init__( total_num_heads, total_num_kv_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -113,7 +116,7 @@ def __init__( self, intermediate_size: int, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -121,15 +124,14 @@ def __init__( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) @@ -145,7 +147,8 @@ class GPTBigCodeBlock(nn.Module): def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -153,9 +156,9 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, linear_method) + self.attn = GPTBigCodeAttention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPTBigMLP(inner_dim, config, linear_method) + self.mlp = GPTBigMLP(inner_dim, config, quant_config) def forward( self, @@ -186,18 +189,24 @@ class GPTBigCodeModel(nn.Module): def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config assert not config.add_cross_attention self.embed_dim = config.hidden_size - - self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.wte = VocabParallelEmbedding(self.vocab_size, + self.embed_dim, + org_num_embeddings=config.vocab_size) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPTBigCodeBlock(config, linear_method) + GPTBigCodeBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -222,18 +231,35 @@ def forward( class GPTBigCodeForCausalLM(nn.Module): + packed_modules_mapping = {"c_attn": ["c_attn"]} + + supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"] + + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + + embedding_padding_modules = [] def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = GPTBigCodeModel(config, linear_method) + self.quant_config = quant_config + self.transformer = GPTBigCodeModel(config, cache_config, quant_config, + lora_config) self.lm_head_weight = self.transformer.wte.weight - self.logits_processor = LogitsProcessor(config.vocab_size) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) self.sampler = Sampler() def forward( @@ -261,14 +287,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: continue if ".attn.bias" in name: diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 94048efe48420..47fd5788a4c35 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -16,28 +16,28 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import GPTJConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -46,7 +46,8 @@ class GPTJAttention(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.total_num_heads = config.num_attention_heads @@ -58,13 +59,13 @@ def __init__( self.head_size, self.total_num_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.out_proj = RowParallelLinear( config.hidden_size, config.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -84,7 +85,11 @@ def __init__( base=rope_theta, is_neox_style=False, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -107,21 +112,20 @@ def __init__( self, intermediate_size: int, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.n_embd self.fc_in = ColumnParallelLinear( hidden_size, intermediate_size, - linear_method=linear_method, + quant_config=quant_config, ) self.fc_out = RowParallelLinear( intermediate_size, hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) @@ -137,14 +141,15 @@ class GPTJBlock(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() inner_dim = (4 * config.n_embd if config.n_inner is None else config.n_inner) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, linear_method) - self.mlp = GPTJMLP(inner_dim, config, linear_method) + self.attn = GPTJAttention(config, cache_config, quant_config) + self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( self, @@ -171,7 +176,8 @@ class GPTJModel(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -180,8 +186,10 @@ def __init__( config.vocab_size, self.embed_dim, ) - self.h = nn.ModuleList( - [GPTJBlock(config, linear_method) for _ in range(config.n_layer)]) + self.h = nn.ModuleList([ + GPTJBlock(config, cache_config, quant_config) + for _ in range(config.n_layer) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -209,13 +217,14 @@ class GPTJForCausalLM(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config assert not config.tie_word_embeddings - self.transformer = GPTJModel(config, linear_method) + self.transformer = GPTJModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.n_embd, @@ -249,11 +258,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -263,8 +268,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "attn.bias" in name or "attn.masked_bias" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index a5b5d717d9846..eb0fcc8f26a58 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -16,28 +16,28 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import GPTNeoXConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -46,7 +46,8 @@ class GPTNeoXAttention(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.total_num_heads = config.num_attention_heads @@ -65,13 +66,13 @@ def __init__( self.head_size, self.total_num_heads, bias=self.bias, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( config.hidden_size, config.hidden_size, bias=self.bias, - linear_method=linear_method, + quant_config=quant_config, ) scaling = self.head_size**-0.5 rotary_dim = int(self.head_size * config.rotary_pct) @@ -85,7 +86,11 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -107,20 +112,19 @@ class GPTNeoXMLP(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.dense_h_to_4h = ColumnParallelLinear( config.hidden_size, config.intermediate_size, - linear_method=linear_method, + quant_config=quant_config, ) self.dense_4h_to_h = RowParallelLinear( config.intermediate_size, config.hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, config.intermediate_size) @@ -136,7 +140,8 @@ class GPTNeoXLayer(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -144,8 +149,8 @@ def __init__( eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, linear_method) - self.mlp = GPTNeoXMLP(config, linear_method) + self.attention = GPTNeoXAttention(config, cache_config, quant_config) + self.mlp = GPTNeoXMLP(config, quant_config) def forward( self, @@ -184,7 +189,8 @@ class GPTNeoXModel(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -194,7 +200,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GPTNeoXLayer(config, linear_method) + GPTNeoXLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, @@ -225,12 +231,13 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.gpt_neox = GPTNeoXModel(config, linear_method) + self.quant_config = quant_config + self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, @@ -263,14 +270,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if ("attention.bias" in name or "attention.masked_bias" in name or "rotary_emb.inv_freq" in name): continue diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index bdb48bf21042e..e75c567f589c8 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,27 +1,27 @@ # -*- coding: utf-8 -*- -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -32,17 +32,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.w2 = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -65,7 +65,8 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -96,13 +97,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.wo = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -115,7 +116,9 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -137,7 +140,8 @@ class InternLMDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -152,13 +156,14 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + cache_config=cache_config, + quant_config=quant_config, ) self.feed_forward = InternLM2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -197,7 +202,8 @@ class InternLM2Model(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -208,7 +214,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, linear_method) + InternLMDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -240,12 +246,13 @@ class InternLM2ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = InternLM2Model(config, linear_method) + self.quant_config = quant_config + self.model = InternLM2Model(config, cache_config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -275,19 +282,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w1", 0), ("gate_up_proj", "w3", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 12fc9dbd50732..869b8fc91fd64 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -20,25 +20,26 @@ """Inference-only Jais model compatible with HuggingFace weights.""" import math -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import JAISConfig @@ -69,7 +70,8 @@ class JAISAttention(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -89,13 +91,13 @@ def __init__( self.head_dim, total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) tp_rank = get_tensor_model_parallel_rank() @@ -103,12 +105,12 @@ def __init__( head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end] - self.attn = Attention( - self.num_heads, - self.head_dim, - scale=self.scale, - alibi_slopes=alibi_slopes, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -129,7 +131,7 @@ def __init__( self, intermediate_size: int, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -138,19 +140,19 @@ def __init__( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_fc2 = (ColumnParallelLinear( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) if self.swiglu else None) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.act = SwiGLUActivation() @@ -170,7 +172,8 @@ class JAISBlock(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -178,9 +181,9 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, linear_method) + self.attn = JAISAttention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = JAISMLP(inner_dim, config, linear_method) + self.mlp = JAISMLP(inner_dim, config, quant_config) def forward( self, @@ -211,7 +214,8 @@ class JAISModel(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -228,7 +232,7 @@ def __init__( else: self.embeddings_scale = config.mup_embeddings_scale self.h = nn.ModuleList([ - JAISBlock(config, linear_method) + JAISBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -262,12 +266,13 @@ class JAISLMHeadModel(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = JAISModel(config, linear_method) + self.quant_config = quant_config + self.transformer = JAISModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale @@ -303,16 +308,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f021be7ee4822..2b8d3573f45cf 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -29,27 +29,26 @@ from vllm import _custom_C from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator, - kv_cache_scales_loader) from vllm.sequence import SamplerOutput -from vllm.utils import is_hip +from vllm.utils import is_hip, print_warning_once class LlamaMLP(nn.Module): @@ -59,17 +58,19 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config) + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -101,9 +102,9 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, bias: bool = False, - sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -128,28 +129,19 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - # This will be overwritten by model initialization if we are using it. - # N.B. currently we only support per tensor scalar scaling factors - # & only applicable to ROCm (AMD GPU). - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - self.kv_scale = 1.0 - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -163,7 +155,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -175,8 +168,7 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata, - self.kv_scale) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -186,15 +178,19 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - sliding_window = getattr(config, "sliding_window", None) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( @@ -207,15 +203,16 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, bias=attention_bias, - sliding_window=sliding_window, + cache_config=cache_config, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -256,7 +253,8 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -272,8 +270,10 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -321,12 +321,8 @@ class LlamaForCausalLM(nn.Module): # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" ] embedding_modules = { "embed_tokens": "input_embeddings", @@ -337,13 +333,16 @@ class LlamaForCausalLM(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = LlamaModel(config, linear_method, lora_config=lora_config) + self.model = LlamaModel(config, + cache_config, + quant_config, + lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -356,6 +355,8 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, @@ -387,22 +388,17 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name @@ -425,20 +421,25 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + f"Found kv scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_kv_scale_name}). kv-scale is " + "not loaded.") + continue + else: + name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - def load_ammo_quantized_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - quant_config: Optional[QuantizationConfig] = None): - import os.path - #import json + def load_quantized_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) #with open("/projects/a.txt", "r") as f: # j = json.load(f) @@ -453,25 +454,15 @@ def load_ammo_quantized_weights( ("self_attn.o_proj", "attention.dense"), ("self_attn.qkv_proj", "attention.qkv"), ] - tp_rank = get_tensor_model_parallel_rank() - quantized_filename: str = quant_config.quantized_weights_path - #quantized_filename = "quantized/osf/rank0.safetensors" - if not quantized_filename.startswith('/'): - quantized_filename = os.path.join(model_name_or_path, - quantized_filename) - #if not os.path.isdir(quantized_filename): - # quantized_filename = os.path.dirname(quantized_filename) - for name, loaded_weight in hf_model_weights_iterator( - quantized_filename, cache_dir, 'safetensors', revision): + for name, loaded_weight in weights: #print(name) name = name.replace('transformer', 'model') name = name.replace('kv_cache_scaling_factor', 'qkv.output_scaling_factor') - #if "output_scaling_factor" in name: - # print(f"{name} {loaded_weight}") - #if "kv_cache_scaling_factor" in name: - # print(f"KVK: {name} {loaded_weight}") - #if "lm_head" in name: - # print(f"LM {name}: {loaded_weight}") + loaded_weight = loaded_weight.to("cuda") + if loaded_weight.dtype == torch.int8: + loaded_weight[loaded_weight == -128] = 0 + assert loaded_weight.is_contiguous + loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz) for (param_name, weight_name, shard_id) in quant_shards: if weight_name not in name: continue @@ -480,23 +471,8 @@ def load_ammo_quantized_weights( if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - if "activation_scaling_factor" in name: - asf_loader = getattr(param, "asf_loader", - default_weight_loader) - asf_loader(param, loaded_weight, shard_id) - elif "weights_scaling_factor" in name: - wsf_loader = getattr(param, "wsf_loader", - default_weight_loader) - #print(f"{name} {shard_id} {loaded_weight}") - wsf_loader(param, loaded_weight, shard_id) - elif "output_scaling_factor" in name: - osf_loader = getattr(param, "osf_loader", - default_weight_loader) - #print(f"{name} {shard_id} {loaded_weight}") - osf_loader(param, loaded_weight, shard_id) - else: - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break else: # Skip loading extra bias for GPTQ models. @@ -507,20 +483,13 @@ def load_ammo_quantized_weights( param = params_dict[name] if "activation_scaling_factor" in name or "weights_scaling_factor" in name: param.data.copy_(loaded_weight) - if "output_scaling_factor" in name: - param.data.copy_(1 / loaded_weight.item()) + elif "output_scaling_factor" in name: + param.data.copy_(loaded_weight) else: weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) break - rescale = [ - "mlp.gate_up_proj.weight" - ] - for name, param in params_dict.items(): - for suffix in rescale: - if name.endswith(suffix): - param.rescale() # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should @@ -541,7 +510,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: # scaling_factor = tensor_amax / FPtype_max scaling_factor *= 2 if hasattr(layer_self_attn, "kv_scale"): - layer_self_attn.kv_scale = scaling_factor + layer_self_attn.attn._kv_scale = scaling_factor else: raise RuntimeError("Self attention has no KV cache scaling " "factor attribute!") diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py new file mode 100644 index 0000000000000..8f1c77da50d96 --- /dev/null +++ b/vllm/model_executor/models/llama_embedding.py @@ -0,0 +1,87 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import PoolerOutput + + +class LlamaEmbeddingModel(nn.Module): + """A model that uses Llama with additional embedding functionalities. + + This class encapsulates the LlamaModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of LlamaModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.model = LlamaModel(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model.forward(input_ids, positions, kv_caches, + attn_metadata, inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.model.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index c2571d0893c8d..fbd7638097286 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch from torch import nn @@ -7,18 +7,20 @@ from transformers import CLIPVisionModel, LlavaConfig from vllm.attention import AttentionMetadata -from vllm.config import VisionLanguageConfig +from vllm.config import CacheConfig, VisionLanguageConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from .vlm_base import VisionLanguageModelBase + _KEYS_TO_MODIFY_MAPPING = { "language_model.lm_head": "lm_head", "language_model.model": "language_model", @@ -40,7 +42,7 @@ def __init__(self, vision_hidden_size: int, text_hidden_size: int, text_hidden_size, bias=True) - def forward(self, image_features): + def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.linear_1(image_features) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) @@ -50,28 +52,46 @@ def forward(self, image_features): def _merge_vision_embeddings(input_ids: torch.Tensor, inputs_embeds: torch.Tensor, vision_embeddings: torch.Tensor, - image_token_id: int): + image_token_id: int) -> torch.Tensor: """In place merges in vision_embeddings with inputs_embeds.""" mask = (input_ids == image_token_id) - inputs_embeds[mask] = vision_embeddings.view(-1, + + image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1] + if mask.sum() != image_feature_size: + raise ValueError(f"image_feature_size should be {image_feature_size}, " + f"but found: {mask.sum()}") + + inputs_embeds[mask] = vision_embeddings.view(image_feature_size, vision_embeddings.shape[-1]) + return inputs_embeds + + +class LlavaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channels, height, width)""" + + +class LlavaImageFeatureInputs(TypedDict): + type: Literal["image_features"] + data: torch.Tensor + """Shape: (batch_size, image_feature_size, hidden_size)""" + + +LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs] + -class LlavaForConditionalGeneration(nn.Module): +class LlavaForConditionalGeneration(VisionLanguageModelBase): def __init__(self, - config: "LlavaConfig", + config: LlavaConfig, vision_language_config: VisionLanguageConfig, - linear_method: Optional["LinearMethodBase"] = None) -> None: - super().__init__() - self.config = config - - self.vision_language_config = vision_language_config + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__(vision_language_config) - assert self.vision_language_config, ( - "Provide `image_input_type` and other vision " - "related configurations through LLM entrypoint " - "or engine arguments.") + self.config = config if self.vision_language_config.image_input_type == ( VisionLanguageConfig.ImageInputType.PIXEL_VALUES): @@ -84,8 +104,9 @@ def __init__(self, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) - self.linear_method = linear_method - self.language_model = LlamaModel(config.text_config, linear_method) + self.quant_config = quant_config + self.language_model = LlamaModel(config.text_config, cache_config, + quant_config) self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, @@ -96,14 +117,96 @@ def __init__(self, config.vocab_size, logit_scale) self.sampler = Sampler() - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - image_input: Optional[torch.Tensor] = None - ) -> SamplerOutput: # noqa: E501 + def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor: + if list(data.shape[1:]) != list( + self.vision_language_config.image_input_shape[1:]): + raise ValueError( + f"The expected image tensor shape is batch dimension plus " + f"{self.vision_language_config.image_input_shape[1:]}. " + f"You supplied {data.shape}. " + f"If you are using vLLM's entrypoint, make sure your " + f"supplied image input is consistent with " + f"image_input_shape in engine args.") + + return data + + def _parse_and_validate_image_input( + self, data: object) -> Optional[LlavaImageInputs]: + expected_input_type = self.vision_language_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + if data is None: + return None + + if expected_input_type == ImageInputType.PIXEL_VALUES: + if not isinstance(data, torch.Tensor): + raise TypeError("Image pixel vector should be a tensor, " + f"but received type: {type(data)}") + + return LlavaImagePixelInputs( + type="pixel_values", + data=self._validate_image_data(data), + ) + elif expected_input_type == ImageInputType.IMAGE_FEATURES: + if not isinstance(data, torch.Tensor): + raise TypeError("Image feature vector should be a tensor, " + f"but received type: {type(data)}") + + return LlavaImageFeatureInputs( + type="image_features", + data=self._validate_image_data(data), + ) + + return None + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, + pixel_values: torch.Tensor) -> torch.Tensor: + # TODO(xwjiang): Maybe port minimal CLIPVisionModel over. + image_outputs = vision_tower(pixel_values.to(vision_tower.device), + output_hidden_states=True) + + image_features = image_outputs.hidden_states[ + self.config.vision_feature_layer] + + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + + def _process_image_pixels(self, + inputs: LlavaImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["data"] + + return self._image_pixels_to_features(self.vision_tower, pixel_values) + + def _process_image_input(self, + image_input: LlavaImageInputs) -> torch.Tensor: + if image_input["type"] == "pixel_values": + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) + else: + image_features = image_input["data"] + + return self.multi_modal_projector(image_features) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + image_input: Optional[torch.Tensor] = None) -> SamplerOutput: """Run forward pass for Llava 1.5. One key thing to understand is the `input_ids` already accounts for the @@ -140,42 +243,20 @@ def forward( For PIXEL_VALUES, expecting [1, 3, 336, 336]. For IMAGE_FEATURES, expecting [1, 576, 1024]. """ - if image_input is not None: - if list(image_input.shape[1:]) != list( - self.vision_language_config.image_input_shape[1:]): - raise ValueError( - f"The expected image tensor shape is batch dimension " - f"plus " - f"{self.vision_language_config.image_input_shape[1:]}." - f" You supplied {image_input.shape}. " - f"If you are using vLLM's entrypoint, make sure your " - f"supplied image input is consistent with " - f"image_input_shape in engine args.") - if self.vision_tower is not None: - # TODO(xwjiang): Maybe port minimal CLIPVisionModel over. - image_outputs = self.vision_tower(image_input, - output_hidden_states=True) - image_features = image_outputs.hidden_states[ - self.config.vision_feature_layer] - # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa - if self.config.vision_feature_select_strategy == "default": - image_features = image_features[:, 1:] - elif self.config.vision_feature_select_strategy == "full": - image_features = image_features - else: - raise ValueError( - f"Unexpected select feature strategy: " - f"{self.config.vision_feature_select_strategy}") - else: - image_features = image_input - vision_embeddings = self.multi_modal_projector(image_features) + parsed_image_input = self._parse_and_validate_image_input(image_input) + + if parsed_image_input is not None: + vision_embeddings = self._process_image_input(parsed_image_input) inputs_embeds = self.language_model.get_input_embeddings(input_ids) - _merge_vision_embeddings( + + inputs_embeds = _merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, self.vision_language_config.image_token_id) + input_ids = None else: inputs_embeds = None + hidden_states = self.language_model(input_ids, positions, kv_caches, @@ -198,11 +279,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # only doing this for language model part for now. stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -213,8 +290,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 99d1b4eb97bb8..59fbf8e1b35f2 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -22,34 +22,33 @@ # limitations under the License. """Inference-only MiniCPM model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -86,7 +85,7 @@ def __init__( self.num_total_experts, bias=False, params_dtype=self.params_dtype, - linear_method=None) + quant_config=None) self.ws = nn.Parameter( torch.empty(self.num_total_experts, @@ -149,17 +148,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -182,7 +181,8 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -213,13 +213,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -235,7 +235,9 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -260,7 +262,8 @@ class MiniCPMDecoderLayer(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -276,7 +279,8 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + cache_config=cache_config, + quant_config=quant_config, ) self.num_experts = getattr(self.config, "num_experts", 0) if self.num_experts == 0: @@ -284,7 +288,7 @@ def __init__( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) else: self.mlp = MiniCPMMoE(num_experts=config.num_experts, @@ -331,7 +335,8 @@ class MiniCPMModel(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -347,7 +352,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, linear_method) + MiniCPMDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -414,15 +419,17 @@ class MiniCPMForCausalLM(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.num_experts = getattr(self.config, "num_experts", 0) - self.linear_method = linear_method + self.quant_config = quant_config self.model = MiniCPMModel(config, - linear_method, + cache_config, + quant_config, lora_config=lora_config) unpadded_vocab_size = config.vocab_size if lora_config: @@ -473,11 +480,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -494,8 +497,7 @@ def load_weights(self, for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 429bc8109b9f8..2f4237339486e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -21,34 +21,36 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import MixtralConfig +from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.utils import print_warning_once class MixtralMoE(nn.Module): @@ -68,6 +70,7 @@ def __init__( intermediate_size: int, params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() @@ -75,37 +78,90 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size // self.tp_size + self.quant_config = quant_config + + # FIXME(pcmoritz): Make this more general to support different + # quantization schemes + self.use_fp8 = isinstance(quant_config, Fp8Config) if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype + # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear(self.hidden_size, self.num_total_experts, bias=False, params_dtype=self.params_dtype, - linear_method=None) + quant_config=None) + + if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn - self.ws = nn.Parameter( + self.w13_weight = nn.Parameter( torch.empty(self.num_total_experts, 2 * self.intermediate_size, self.hidden_size, - device="cuda", - dtype=self.params_dtype)) - self.w2s = nn.Parameter( + dtype=params_dtype)) + self.w2_weight = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, self.intermediate_size, - device="cuda", - dtype=self.params_dtype)) + dtype=params_dtype)) - set_weight_attrs(self.ws, { + set_weight_attrs(self.w13_weight, { "weight_loader": self.weight_loader, }) - set_weight_attrs(self.w2s, { + set_weight_attrs(self.w2_weight, { "weight_loader": self.weight_loader, }) + # Used for fp8. + self.w13_scale = None + self.w2_scale = None + self.a13_scale = None + self.a2_scale = None + + if self.use_fp8: + # WEIGHT_SCALE (for fp8) + self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, + dtype=torch.float32), + requires_grad=False) + self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, + dtype=torch.float32), + requires_grad=False) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(self.w13_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2_scale, { + "weight_loader": self.weight_loader, + }) + + # ACT_SCALE (for fp8) + if quant_config.activation_scheme == "static": + if not quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + self.a13_scale = nn.Parameter(torch.zeros( + self.num_total_experts, dtype=torch.float32), + requires_grad=False) + self.a2_scale = nn.Parameter(torch.zeros( + self.num_total_experts, dtype=torch.float32), + requires_grad=False) + + set_weight_attrs(self.a13_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.a2_scale, { + "weight_loader": self.weight_loader, + }) + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int): tp_rank = get_tensor_model_parallel_rank() @@ -119,6 +175,49 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_size:2 * shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] + if "act_scale" in weight_name or "weight_scale" in weight_name: + param_data[expert_id] = loaded_weight + + def process_weights_after_loading(self): + # Fp8 is the only case where we need to process after loading. + if not self.use_fp8: + return + + # If checkpoint is fp16, quantize here. + if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like(self.w13_weight.data, + dtype=torch.float8_e4m3fn) + w2_weight = torch.empty_like(self.w2_weight.data, + dtype=torch.float8_e4m3fn) + for expert in range(self.num_total_experts): + w13_weight[expert, :, :], self.w13_scale[ + expert] = ops.scaled_fp8_quant( + self.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], self.w2_scale[ + expert] = ops.scaled_fp8_quant( + self.w2_weight.data[expert, :, :]) + self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) + self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) + + # If checkpoint is fp8 + static, cleanup act_scales. + # Since state_dict has an act_scale per expert but our kernels + # are passed one act_scale shared across all experts. + elif self.quant_config.activation_scheme == "static": + if self.a13_scale is None or self.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + + if (not all_close_1d(self.a13_scale) + or not all_close_1d(self.a2_scale)): + print_warning_once( + "Found act_scales that are not equal for fp8 MoE layer. " + "Using the maximum across experts for each layer. ") + + self.a13_scale = nn.Parameter(self.a13_scale.max(), + requires_grad=False) + self.a2_scale = nn.Parameter(self.a2_scale.max(), + requires_grad=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape @@ -126,12 +225,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, + self.w13_weight, + self.w2_weight, router_logits, self.top_k, renormalize=True, - inplace=True) + inplace=True, + use_fp8=self.use_fp8, + w1_scale=self.w13_scale, + w2_scale=self.w2_scale, + a1_scale=self.a13_scale, + a2_scale=self.a2_scale) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -142,14 +246,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -171,7 +277,6 @@ def __init__(self, self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window self.qkv_proj = QKVParallelLinear( hidden_size, @@ -179,13 +284,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -194,13 +299,12 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -222,7 +326,8 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -234,13 +339,14 @@ def __init__( max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - sliding_window=config.sliding_window, - linear_method=linear_method) + cache_config=cache_config, + quant_config=quant_config) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) + intermediate_size=config.intermediate_size, + quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -280,7 +386,8 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -296,7 +403,9 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, linear_method=linear_method) + MixtralDecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -320,6 +429,8 @@ def forward( class MixtralForCausalLM(nn.Module): + fall_back_to_pt_during_load = False + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -344,14 +455,15 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method self.model = MixtralModel(config, - linear_method, + cache_config, + quant_config, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -394,11 +506,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -407,20 +515,30 @@ def load_weights(self, ] expert_params_mapping = [ + # These are the weight scales for the experts # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", + ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id) + ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", f"experts.{expert_id}.{weight_name}.weight", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the activation scales for the experts + # (param_name, weight_name, expert_id) + ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", + f"experts.{expert_id}.{weight_name}.act_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -451,7 +569,26 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 75f86bc134ee3..1894c05e167d6 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import numpy as np import torch @@ -30,23 +30,23 @@ from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -57,7 +57,7 @@ def __init__( num_experts: int, hidden_size: int, intermediate_size: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.num_experts = num_experts @@ -67,15 +67,15 @@ def __init__( self.w1 = ReplicatedLinear(self.hidden_dim, self.ffn_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.w2 = ReplicatedLinear(self.ffn_dim, self.hidden_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.w3 = ReplicatedLinear(self.hidden_dim, self.ffn_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) # TODO: Use vllm's SiluAndMul self.act_fn = nn.SiLU() @@ -94,7 +94,7 @@ class MixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -117,14 +117,14 @@ def __init__( MixtralMLP(self.num_total_experts, config.hidden_size, config.intermediate_size, - linear_method=linear_method) + quant_config=quant_config) if idx in self.expert_indicies else None for idx in range(self.num_total_experts) ]) self.gate = ReplicatedLinear(config.hidden_size, self.num_total_experts, bias=False, - linear_method=None) + quant_config=None) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -158,14 +158,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -187,7 +189,6 @@ def __init__(self, self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window self.qkv_proj = QKVParallelLinear( hidden_size, @@ -195,13 +196,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -210,13 +211,12 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -238,7 +238,8 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -250,10 +251,10 @@ def __init__( max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - sliding_window=config.sliding_window, - linear_method=linear_method) + cache_config=cache_config, + quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, - linear_method=linear_method) + quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -293,7 +294,8 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -304,7 +306,9 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, linear_method=linear_method) + MixtralDecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -328,16 +332,18 @@ def forward( class MixtralForCausalLM(nn.Module): + fall_back_to_pt_during_load = False def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = MixtralModel(config, linear_method) + self.quant_config = quant_config + self.model = MixtralModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -367,11 +373,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -380,12 +382,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index a39f94359a948..5f9e4d86f3cd8 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -1,26 +1,27 @@ # coding=utf-8 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.mpt import MPTConfig @@ -43,7 +44,8 @@ class MPTAttention(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model @@ -66,7 +68,7 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) if self.qk_ln: self.q_ln = nn.LayerNorm(self.d_model) @@ -75,7 +77,7 @@ def __init__( self.d_model, self.d_model, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -107,7 +109,9 @@ def __init__( self.head_dim, scaling, alibi_slopes=alibi_slopes, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -134,7 +138,7 @@ class MPTMLP(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.d_model @@ -144,15 +148,14 @@ def __init__( hidden_size, intermediate_size, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn("gelu", quant_config, intermediate_size) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -167,14 +170,15 @@ class MPTBlock(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, linear_method) + self.attn = MPTAttention(config, cache_config, quant_config) self.norm_2 = nn.LayerNorm(hidden_size) - self.ffn = MPTMLP(config, linear_method) + self.ffn = MPTMLP(config, quant_config) def forward( self, @@ -202,7 +206,8 @@ class MPTModel(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() assert config.embedding_fraction == 1.0 @@ -212,8 +217,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList( - [MPTBlock(config, linear_method) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([ + MPTBlock(config, cache_config, quant_config) + for _ in range(config.n_layers) + ]) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -247,14 +254,15 @@ class MPTForCausalLM(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config assert config.tie_word_embeddings - self.linear_method = linear_method + self.quant_config = quant_config - self.transformer = MPTModel(config, linear_method) + self.transformer = MPTModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -284,14 +292,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 611a48a9aad2b..39270f71ec46f 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -1,65 +1,48 @@ # coding=utf-8 # Adapted from -# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and -# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py -# Copyright 2023 The vLLM team. -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. +# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. # -# BSD 3-Clause License +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. # -# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. -# All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: +# http://www.apache.org/licenses/LICENSE-2.0 # -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch -# this model must need this dependency -from hf_olmo import OLMoConfig from torch import nn +from transformers import OlmoConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -72,56 +55,56 @@ class OlmoAttention(nn.Module): def __init__( self, - config: OLMoConfig, - linear_method: Optional[LinearMethodBase] = None, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.hidden_size = config.d_model - assert config.d_model % config.n_heads == 0 + self.hidden_size = config.hidden_size tensor_model_parallel_world_size = ( get_tensor_model_parallel_world_size()) - self.total_num_heads = self.config.n_heads + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // tensor_model_parallel_world_size) self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.clip_qkv = config.clip_qkv - # Layer norms. - self.attn_norm = nn.LayerNorm(config.d_model, - elementwise_affine=False, - bias=False) # Attention input projection. Projects x -> (q, k, v) - self.att_proj = QKVParallelLinear( - config.d_model, + self.qkv_proj = QKVParallelLinear( + self.hidden_size, self.head_dim, self.total_num_heads, - bias=config.include_bias, - linear_method=linear_method, + bias=config.attention_bias, + quant_config=quant_config, ) # Rotary embeddings. - if self.config.rope: - rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, - "max_position_embeddings", 8192) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) self.scaling = self.head_dim**-0.5 self.attn = Attention(self.num_heads, self.head_dim, - scale=self.scaling) + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) # Attention output projection. - self.attn_out = RowParallelLinear( - config.d_model, - config.d_model, - bias=config.include_bias, - linear_method=linear_method, + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + quant_config=quant_config, ) def forward( @@ -131,13 +114,13 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - hidden_states = self.attn_norm(hidden_states) - qkv, _ = self.att_proj(hidden_states) + qkv, _ = self.qkv_proj(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) q, k, v = qkv.chunk(chunks=3, dim=-1) - if self.config.rope: - q, k = self.rotary_emb(positions, q, k) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.attn_out(attn_output) + output, _ = self.o_proj(attn_output) return output @@ -150,57 +133,44 @@ class OlmoMLP(nn.Module): def __init__( self, - config: OLMoConfig, - linear_method: Optional[LinearMethodBase] = None, + config: OlmoConfig, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size - is not None else config.mlp_ratio * config.d_model) - - # Layer norms. - self.ff_norm = nn.LayerNorm(config.d_model, - elementwise_affine=False, - bias=False) + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size # Feed-forward input projection. - self.ff_proj = MergedColumnParallelLinear( - config.d_model, - [self.hidden_size // 2] * 2, - bias=config.include_bias, - linear_method=linear_method, + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, ) # Activation function. - self.act = SiluAndMul() - self.act.output_multiplier = 0.5 - assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + self.act_fn = SiluAndMul() # Feed-forward output projection. - self.ff_out = RowParallelLinear( - int(self.act.output_multiplier * self.hidden_size), - config.d_model, - bias=config.include_bias, - linear_method=linear_method, + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, ) def forward( self, x: torch.Tensor, ) -> torch.Tensor: - # Add feed-forward projection. - # shape: (batch_size, seq_len, d_model) - og_x = x - x = self.ff_norm(x) - x, _ = self.ff_proj(x) - x = self.act(x) - x, _ = self.ff_out(x) - x = og_x + x - + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) return x -class OlmoBlock(nn.Module): +class OlmoDecoderLayer(nn.Module): """ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` @@ -208,14 +178,23 @@ class OlmoBlock(nn.Module): """ def __init__(self, - config: OLMoConfig, - linear_method: Optional[LinearMethodBase] = None): + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() # Attention block. - self.attn = OlmoAttention(config, linear_method) + self.self_attn = OlmoAttention(config, cache_config, quant_config) # MLP block. - self.mlp = OlmoMLP(config, linear_method) + self.mlp = OlmoMLP(config, quant_config) + + # LayerNorm + self.input_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) def forward( self, @@ -225,52 +204,38 @@ def forward( attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Attention block. - og_x = hidden_states - x = self.attn(positions, hidden_states, kv_cache, attn_metadata) - x = x + og_x + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = hidden_states + residual # MLP block. - hidden_states = self.mlp(x) + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states return hidden_states class OlmoModel(nn.Module): def __init__(self, - config: OLMoConfig, - linear_method: Optional[LinearMethodBase] = None): + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.transformer = nn.ModuleDict( - dict( - wte=VocabParallelEmbedding( - config.embedding_size or config.vocab_size, - config.d_model, - ), - ln_f=nn.LayerNorm(config.d_model, - elementwise_affine=False, - bias=False), - )) - - blocks = [ - OlmoBlock(config, linear_method) for i in range(config.n_layers) - ] - if self.config.block_group_size > 1: - raise NotImplementedError("Block group size > 1 not supported yet") - else: - self.transformer.update({"blocks": nn.ModuleList(blocks)}) - - if not config.weight_tying: - self.transformer.update({ - "ff_out": - ColumnParallelLinear( - config.d_model, - config.embedding_size or config.vocab_size, - bias=config.include_bias, - linear_method=linear_method, - ) - }) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.layers = nn.ModuleList([ + OlmoDecoderLayer(config, cache_config, quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) def forward( self, @@ -284,39 +249,49 @@ def forward( """ # Get embeddings of input. # shape: (batch_size, seq_len, d_model) - x = self.transformer.wte(input_ids) # type: ignore + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds # Apply blocks one-by-one. - for block_idx, block in enumerate(self.transformer.blocks): + for layer_idx, decoder_layer in enumerate(self.layers): # shape: (batch_size, seq_len, d_model) - x = block( + hidden_states = decoder_layer( positions, - x, - kv_caches[block_idx], + hidden_states, + kv_caches[layer_idx], attn_metadata, ) # Apply final layer norm. # shape: (batch_size, seq_len or 1, d_model) - x = self.transformer.ln_f(x) # type: ignore - return x + hidden_states = self.norm(hidden_states) + return hidden_states -class OLMoForCausalLM(nn.Module): +class OlmoForCausalLM(nn.Module): """ Extremely barebones HF model wrapper. """ def __init__(self, - config: OLMoConfig, - linear_method: Optional[LinearMethodBase] = None): + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.linear_method = linear_method - self.model = OlmoModel(config, linear_method) - self.lm_head_weight = (self.model.transformer.wte.weight - if config.weight_tying else - self.model.transformer.ff_out.weight) + self.model = OlmoModel(config, cache_config, quant_config) + if config.tie_word_embeddings: + self.lm_head_weight = self.model.embed_tokens.weight + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.lm_head_weight = self.lm_head.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -349,28 +324,40 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): - # attention - if ".att" in name: - name = name.replace(".att", ".attn.att") - # mlp - if ".ff_proj" in name: - name = name.replace(".ff_proj", ".mlp.ff_proj") - # Reverse the weight for the MergeColumnParallelLinear - loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1]) - if ".ff_out" in name and "transformer.ff_out" not in name: - name = name.replace(".ff_out", ".mlp.ff_out") - # there is no bias in olmo - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index c1ae1b2ae0f03..4bf59105dbabb 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -17,28 +17,28 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OPT model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import OPTConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -62,7 +62,8 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.embed_dim = embed_dim @@ -79,17 +80,19 @@ def __init__( self.head_dim, total_num_heads, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.out_proj = RowParallelLinear( embed_dim, embed_dim, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.attn = Attention(self.num_heads, self.head_dim, - scale=self.scaling) + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -109,7 +112,8 @@ class OPTDecoderLayer(nn.Module): def __init__( self, config: OPTConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -118,7 +122,8 @@ def __init__( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, bias=config.enable_bias, - linear_method=linear_method, + cache_config=cache_config, + quant_config=quant_config, ) self.do_layer_norm_before = config.do_layer_norm_before @@ -129,16 +134,15 @@ def __init__( self.embed_dim, config.ffn_dim, bias=config.enable_bias, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.activation_fn = get_act_fn(config.activation_function, quant_config, config.ffn_dim) self.fc2 = RowParallelLinear( config.ffn_dim, self.embed_dim, bias=config.enable_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.final_layer_norm = nn.LayerNorm( self.embed_dim, @@ -183,7 +187,8 @@ class OPTDecoder(nn.Module): def __init__( self, config: OPTConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -204,7 +209,7 @@ def __init__( self.project_out = ReplicatedLinear(config.hidden_size, config.word_embed_proj_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) else: self.project_out = None @@ -212,7 +217,7 @@ def __init__( self.project_in = ReplicatedLinear(config.word_embed_proj_dim, config.hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) else: self.project_in = None @@ -228,7 +233,7 @@ def __init__( self.final_layer_norm = None self.layers = nn.ModuleList([ - OPTDecoderLayer(config, linear_method) + OPTDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -261,10 +266,11 @@ class OPTModel(nn.Module): def __init__( self, config: OPTConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.decoder = OPTDecoder(config, linear_method) + self.decoder = OPTDecoder(config, cache_config, quant_config) def forward( self, @@ -281,12 +287,13 @@ class OPTForCausalLM(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.model = OPTModel(config, linear_method) + self.quant_config = quant_config + self.model = OPTModel(config, cache_config, quant_config) self.lm_head_weight = self.model.decoder.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -316,11 +323,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -328,8 +331,7 @@ def load_weights(self, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: continue if name.startswith("decoder."): diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index ee910563b20df..133a10e6bb3e8 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -4,28 +4,28 @@ # Copyright (c) OrionStar Inc. # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -36,17 +36,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -69,7 +69,8 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -100,13 +101,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -119,7 +120,9 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -141,7 +144,8 @@ class OrionDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -156,13 +160,14 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + cache_config=cache_config, + quant_config=quant_config, ) self.mlp = OrionMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = nn.LayerNorm(config.hidden_size, @@ -203,7 +208,8 @@ class OrionModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -214,7 +220,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - OrionDecoderLayer(config, linear_method) + OrionDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -246,12 +252,13 @@ class OrionForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = OrionModel(config, linear_method) + self.quant_config = quant_config + self.model = OrionModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -281,11 +288,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -295,8 +298,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 40e068acaba7d..c8e61735a9bb6 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -35,28 +35,28 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -64,7 +64,8 @@ class PhiAttention(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size @@ -82,12 +83,12 @@ def __init__(self, self.head_size, self.total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) scaling = self.head_size**-0.5 @@ -106,7 +107,11 @@ def __init__(self, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -127,7 +132,7 @@ class PhiMLP(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() n_inner = getattr(config, "n_inner", None) @@ -136,14 +141,13 @@ def __init__(self, self.fc1 = ColumnParallelLinear( config.hidden_size, n_inner, - linear_method=linear_method, + quant_config=quant_config, ) self.fc2 = RowParallelLinear( n_inner, config.hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, n_inner) def forward(self, hidden_states): @@ -157,12 +161,13 @@ class PhiLayer(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, linear_method) - self.mlp = PhiMLP(config, linear_method) + self.self_attn = PhiAttention(config, cache_config, quant_config) + self.mlp = PhiMLP(config, quant_config) def forward( self, @@ -188,14 +193,15 @@ class PhiModel(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - PhiLayer(config, linear_method) + PhiLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layernorm = nn.LayerNorm(config.hidden_size, @@ -224,15 +230,37 @@ def forward( class PhiForCausalLM(nn.Module): - - def __init__(self, - config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ] + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "dense", + "fc1", + "fc2", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + del lora_config # Unused. super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config - self.model = PhiModel(config, linear_method) + self.model = PhiModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, @@ -266,11 +294,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -279,8 +303,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py new file mode 100644 index 0000000000000..0c5298eb6f100 --- /dev/null +++ b/vllm/model_executor/models/phi3_small.py @@ -0,0 +1,447 @@ +import math +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput + + +def load_column_parallel_weight(param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + tp = get_tensor_model_parallel_world_size() + rk = get_tensor_model_parallel_rank() + assert param.size(0) * tp == loaded_weight.size(0) + s = rk * param.size(0) + e = (rk + 1) * param.size(0) + loaded_weight = loaded_weight[s:e] + assert param.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + + +class HeadMajorQKVParallelLinear(QKVParallelLinear): + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + return load_column_parallel_weight(param, loaded_weight) + + +class HeadMajorColumnParallelLinear(MergedColumnParallelLinear): + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + return load_column_parallel_weight(param, loaded_weight) + + +@torch.jit.script +def quick_gelu(x): + return x * torch.sigmoid(1.702 * x) + + +@torch.jit.script +def gegelu(input, limit: Optional[float] = None): + a_gelu, a_linear = input[..., ::2], input[..., 1::2] + if limit is not None: + a_gelu = torch.where(torch.isinf(a_gelu), a_gelu, + a_gelu.clamp(min=None, max=limit)) + a_linear = torch.where( + torch.isinf(a_linear), + a_linear, + a_linear.clamp(min=-limit, max=limit), + ) + out_gelu = quick_gelu(a_gelu) + return out_gelu * (a_linear + 1) + + +class Phi3SmallMLP(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + assert (self.config.hidden_act == "gegelu" + ), "Only `gegelu` is supported for the 4.7 series of models .." + self.hidden_size = config.hidden_size + self.gegelu_limit = config.gegelu_limit + self.intermediate_size = config.intermediate_size + + self.up_proj = HeadMajorColumnParallelLinear( + self.hidden_size, + 2 * [self.intermediate_size], + bias=True, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + gate_up, _ = self.up_proj(x) + x = gegelu(gate_up) + x, _ = self.down_proj(x) + return x + + +class Phi3SmallSelfAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.sparse_block_size = config.blocksparse_block_size + self.homo_heads = config.blocksparse_homo_head_pattern + self.local_blocks = config.blocksparse_num_local_blocks + self.vert_stride = config.blocksparse_vert_stride + + assert (config.blocksparse_block_size == + config.blocksparse_triton_kernel_block_size) + + self.hidden_size = config.hidden_size + # Number of Query Heads + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // self.num_heads + self.tp_size = get_tensor_model_parallel_world_size() + # Number of total Key Value Heads before tensor parallel + self.num_key_value_heads = config.num_key_value_heads + self.num_q_per_kv = self.num_heads // self.num_key_value_heads + if self.tp_size > 1: + assert self.num_key_value_heads % self.tp_size == 0 + self.num_kv_heads_per_partion = max( + 1, self.num_key_value_heads // self.tp_size) + self.num_heads_per_partition = self.num_heads // self.tp_size + + self.max_position_embeddings = config.max_position_embeddings + self.rope_embedding_base = config.rope_embedding_base + self.rope_position_scale = config.rope_position_scale + self.is_causal = True + + norm_factor = None + if config.mup_use_scaling: + norm_factor = self.head_dim / config.mup_attn_multiplier + else: + norm_factor = math.sqrt(self.head_dim) + self.scale = 1 / norm_factor + + self.query_key_value = HeadMajorQKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=True, + quant_config=quant_config, + ) + + self.dense = RowParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config) + + if getattr(self.config, "rope_scaling", None) is not None: + rope_scaling = self.config.rope_scaling + for key in rope_scaling: + if isinstance(rope_scaling[key], list): + rope_scaling[key] = tuple(rope_scaling[key]) + + if "factor" not in rope_scaling: + rope_scaling["factor"] = self.rope_position_scale + else: + rope_scaling = { + "type": "linear", + "factor": self.rope_position_scale, + } + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_embedding_base, + rope_scaling=rope_scaling, + ) + + # blocksparse params + self.blocksparse_block_size = config.blocksparse_block_size + self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks + self.blocksparse_vert_stride = config.blocksparse_vert_stride + + use_dense_attn = (getattr(self.config, + "dense_attention_every_n_layers", None) + and (self.layer_idx + 1) % + self.config.dense_attention_every_n_layers == 0) + + bs_params = None + if not use_dense_attn: + bs_params = { + 'max_seqlen': self.max_position_embeddings, + 'num_heads': self.num_heads_per_partition, + "num_kv_heads": self.num_kv_heads_per_partion, + "block_size": self.sparse_block_size, + "local_blocks": self.local_blocks, + "vert_stride": self.vert_stride, + "homo_head": self.homo_heads + } + + self.attn = Attention( + self.num_heads_per_partition, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads_per_partion, + cache_config=cache_config, + quant_config=quant_config, + blocksparse_params=bs_params, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + qkv, _ = self.query_key_value(hidden_states) + + qkv = qkv.view(qkv.shape[:-1] + + (-1, (self.num_q_per_kv + 2), self.head_dim)) + q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2) + + # NOTE: this is required by RotaryEmbed, which indeed does not have to + # TODO: allow 3D QK for rotary forward + q = q.reshape(-1, self.head_dim * self.num_heads_per_partition) + k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) + v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata) + output, _ = self.dense(attn_output) + + return output + + +class Phi3SmallDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3SmallSelfAttention(config, + layer_idx, + cache_config=cache_config, + quant_config=quant_config) + self.mlp = Phi3SmallMLP(config, quant_config) + + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Phi3SmallModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.mup_embedding_multiplier = config.mup_embedding_multiplier + self.layers = nn.ModuleList([ + Phi3SmallDecoderLayer(config, layer_idx, cache_config, + quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata = None, + ): + hidden_states = self.embed_tokens(input_ids) + if (self.mup_embedding_multiplier is not None + and self.mup_embedding_multiplier > 0.0): + hidden_states = hidden_states * self.mup_embedding_multiplier + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + attn_metadata, + ) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class Phi3SmallForCausalLM(nn.Module): + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Phi3SmallModel(config, cache_config, quant_config) + self.vocab_size = config.vocab_size + self.mup_width_multiplier = config.mup_width_multiplier + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + # tokens in tiktoken but not used + if hasattr(config, 'dummy_token_indices'): + device = self.lm_head.weight.device + self.register_buffer('dummy_token_indices', + torch.LongTensor( + config.dummy_token_indices).to(device), + persistent=False) + else: + self.dummy_token_indices = None + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, value): + self.lm_head = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + if self.dummy_token_indices is not None and logits is not None: + logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) + return logits + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + output_hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + output_hidden_states = output_hidden_states + return output_hidden_states + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + + next_tokens = self.sampler(logits / self.mup_width_multiplier, + sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a63b9c8d63d13..d22ea6b79de0f 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -4,29 +4,29 @@ # Copyright (c) Alibaba Cloud. # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -37,17 +37,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str = "silu", - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.c_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -69,7 +69,8 @@ def __init__( max_position_embeddings: int, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = hidden_size @@ -85,13 +86,13 @@ def __init__( self.head_dim, self.total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.scaling = self.head_dim**-0.5 @@ -102,7 +103,11 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -124,7 +129,8 @@ class QWenBlock(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -136,13 +142,14 @@ def __init__( config.max_position_embeddings, rope_theta=rope_theta, rope_scaling=rope_scaling, - linear_method=linear_method) + cache_config=cache_config, + quant_config=quant_config) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, - linear_method=linear_method) + quant_config=quant_config) def forward( self, @@ -176,7 +183,8 @@ class QWenModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -187,7 +195,7 @@ def __init__( config.hidden_size, ) self.h = nn.ModuleList([ - QWenBlock(config, linear_method) + QWenBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -219,12 +227,13 @@ class QWenLMHeadModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = QWenModel(config, linear_method) + self.quant_config = quant_config + self.transformer = QWenModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -254,19 +263,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w2", 0), ("gate_up_proj", "w1", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8c92cd773f6b9..9a4829a27873e 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -22,30 +22,29 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import Qwen2Config from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -56,17 +55,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -87,9 +86,9 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - use_sliding_window: bool = False, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None) -> None: + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -111,7 +110,6 @@ def __init__(self, self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window if use_sliding_window else None self.qkv_proj = QKVParallelLinear( hidden_size, @@ -119,13 +117,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -133,12 +131,14 @@ def __init__(self, rotary_dim=self.head_dim, max_position=max_position, base=self.rope_theta, + rope_scaling=rope_scaling, ) self.attn = Attention(self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window) + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -160,29 +160,28 @@ class Qwen2DecoderLayer(nn.Module): def __init__( self, config: Qwen2Config, - layer_idx: int, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) - use_sliding_window = (config.use_sliding_window - and layer_idx < config.max_window_layers) + rope_scaling = getattr(config, "rope_scaling", None) self.self_attn = Qwen2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - use_sliding_window=use_sliding_window, - linear_method=linear_method, - sliding_window=config.sliding_window) + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -223,7 +222,8 @@ class Qwen2Model(nn.Module): def __init__( self, config: Qwen2Config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -235,8 +235,8 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, linear_method) - for layer_idx in range(config.num_hidden_layers) + Qwen2DecoderLayer(config, cache_config, quant_config) + for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -288,14 +288,27 @@ class Qwen2ForCausalLM(nn.Module): def __init__( self, config: Qwen2Config, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: del lora_config + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = %s is less than " + "`num_hidden_layers` = %s. Please open an issue " + "to discuss this feature." % ( + config.max_window_layers, + config.num_hidden_layers, + )) + super().__init__() self.config = config - self.linear_method = linear_method - self.model = Qwen2Model(config, linear_method) + self.quant_config = quant_config + self.model = Qwen2Model(config, cache_config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight @@ -332,11 +345,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -346,8 +355,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if self.config.tie_word_embeddings and "lm_head.weight" in name: diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6b4a74198fd52..564536f2dd248 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch import torch.nn.functional as F @@ -30,26 +30,26 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -60,18 +60,18 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, reduce_results=reduce_results) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -90,7 +90,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -107,7 +107,7 @@ def __init__( Qwen2MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False) for idx in range(self.n_routed_experts) ]) @@ -116,13 +116,13 @@ def __init__( self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, bias=False, - linear_method=None) + quant_config=None) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False, ) else: @@ -188,7 +188,8 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -219,14 +220,14 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -239,7 +240,9 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -262,7 +265,8 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -277,18 +281,20 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + cache_config=cache_config, + quant_config=quant_config, ) - if (config.num_experts is not None - and (layer_idx + 1) % config.decoder_sparse_step == 0): + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): self.mlp = Qwen2MoeSparseMoeBlock(config=config, - linear_method=linear_method) + quant_config=quant_config) else: self.mlp = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -329,7 +335,8 @@ class Qwen2MoeModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -342,7 +349,8 @@ def __init__( self.layers = nn.ModuleList([ Qwen2MoeDecoderLayer(config, layer_idx, - linear_method=linear_method) + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -367,15 +375,18 @@ def forward( class Qwen2MoeForCausalLM(nn.Module): + fall_back_to_pt_during_load = False + def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = Qwen2MoeModel(config, linear_method) + self.quant_config = quant_config + self.model = Qwen2MoeModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -405,11 +416,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -420,12 +427,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -439,6 +441,9 @@ def load_weights(self, if (("mlp.experts." in name or "mlp.shared_expert." in name) and name not in params_dict): continue + if name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -451,6 +456,9 @@ def load_weights(self, if (("mlp.experts." in name or "mlp.shared_expert." in name) and name not in params_dict): continue + if name not in params_dict: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index b83637fd50dc7..a6ed3800bed0f 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -19,28 +19,28 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -48,7 +48,7 @@ class StablelmMLP(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -56,7 +56,7 @@ def __init__(self, self.gate_up_proj = MergedColumnParallelLinear( config.hidden_size, [config.intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=False) @@ -73,7 +73,8 @@ class StablelmAttention(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -111,11 +112,11 @@ def __init__(self, self.total_num_heads, self.total_num_key_value_heads, self.qkv_bias, - linear_method=linear_method) + quant_config=quant_config) self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, self.hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.rotary_ndims, @@ -125,7 +126,9 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_key_value_heads) + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -147,11 +150,12 @@ class StablelmDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.self_attn = StablelmAttention(config) - self.mlp = StablelmMLP(config, linear_method) + self.self_attn = StablelmAttention(config, cache_config, quant_config) + self.mlp = StablelmMLP(config, quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) @@ -189,14 +193,15 @@ class StableLMEpochModel(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ - StablelmDecoderLayer(config, linear_method) + StablelmDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) norm_eps = getattr(config, "norm_eps", @@ -228,12 +233,13 @@ class StablelmForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = StableLMEpochModel(config, linear_method) + self.quant_config = quant_config + self.model = StableLMEpochModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -263,11 +269,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -277,8 +279,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 50d23e0a3b6ef..4324bf50d4ad1 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -18,28 +18,28 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Starcoder2 model.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import Starcoder2Config from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -47,7 +47,8 @@ class Starcoder2Attention(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -73,7 +74,6 @@ def __init__(self, self.rope_theta = config.rope_theta self.max_position_embeddings = config.max_position_embeddings self.use_bias = config.use_bias - self.sliding_window = config.sliding_window self.qkv_proj = QKVParallelLinear( self.hidden_size, @@ -81,13 +81,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=self.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=self.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -96,13 +96,12 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -123,21 +122,20 @@ class Starcoder2MLP(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.c_fc = ColumnParallelLinear( config.hidden_size, config.intermediate_size, bias=config.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( config.intermediate_size, config.hidden_size, bias=config.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, config.intermediate_size) @@ -152,12 +150,14 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Starcoder2Attention(config, - linear_method=linear_method) - self.mlp = Starcoder2MLP(config, linear_method=linear_method) + cache_config, + quant_config=quant_config) + self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, @@ -194,7 +194,8 @@ class Starcoder2Model(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -204,7 +205,9 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, linear_method=linear_method) + Starcoder2DecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -229,10 +232,13 @@ class Starcoder2ForCausalLM(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = Starcoder2Model(config, linear_method=linear_method) + self.model = Starcoder2Model(config, + cache_config, + quant_config=quant_config) self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: @@ -275,11 +281,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -288,8 +290,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/vlm_base.py b/vllm/model_executor/models/vlm_base.py new file mode 100644 index 0000000000000..eb0aa96e50d59 --- /dev/null +++ b/vllm/model_executor/models/vlm_base.py @@ -0,0 +1,12 @@ +from torch import nn + +from vllm.config import VisionLanguageConfig + + +class VisionLanguageModelBase(nn.Module): + """Base class for all vision language models (VLMs).""" + + def __init__(self, vision_language_config: VisionLanguageConfig) -> None: + super().__init__() + + self.vision_language_config = vision_language_config diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 83d2ddb2bcf35..1e5280dde3ff9 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -20,30 +20,29 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Xverse model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -54,17 +53,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -87,9 +86,9 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, bias: bool = False, - sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -114,13 +113,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -134,7 +133,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -156,7 +156,8 @@ class XverseDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -164,7 +165,6 @@ def __init__( rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - sliding_window = getattr(config, "sliding_window", None) self.self_attn = XverseAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -173,15 +173,15 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, bias=getattr(config, "bias", False), - sliding_window=sliding_window, + cache_config=cache_config, ) self.mlp = XverseMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -222,7 +222,8 @@ class XverseModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -238,7 +239,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - XverseDecoderLayer(config, linear_method) + XverseDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -296,13 +297,14 @@ class XverseForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config=None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = XverseModel(config, linear_method) + self.quant_config = quant_config + self.model = XverseModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -332,11 +334,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -345,8 +343,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if ("rotary_emb.inv_freq" in name or "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name): diff --git a/vllm/model_executor/parallel_utils/README.md b/vllm/model_executor/parallel_utils/README.md deleted file mode 100644 index b25e3afddad9c..0000000000000 --- a/vllm/model_executor/parallel_utils/README.md +++ /dev/null @@ -1 +0,0 @@ -The files in this folder are ported from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core). We only keep the codes that are used in inference. \ No newline at end of file diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py deleted file mode 100644 index 9cbb40708dd5b..0000000000000 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ /dev/null @@ -1,201 +0,0 @@ -from collections import namedtuple -from typing import Any, Dict, List, Optional, Union - -import torch -from torch.distributed import ProcessGroup - -from vllm.model_executor.parallel_utils import pynccl_utils -from vllm.model_executor.parallel_utils.custom_all_reduce import ( - custom_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce) - - -def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: - """All-reduce the input tensor across model parallel group. - - NOTE: This operation will be applied in-place on the input tensor if - disable_custom_all_reduce is set to True. Otherwise, this operation may or - may not be applied in place depending on whether custom all reduce is - invoked for a particular tensor, which further depends on the tensor size - and GPU topology. - - TLDR: always assume this function modifies its input, but use the return - value as the output. - """ - # Bypass the function if we are using only 1 GPU. - if get_tensor_model_parallel_world_size() == 1: - return input_ - out = custom_all_reduce(input_) - if out is not None: - return out - if is_pynccl_enabled_for_all_reduce(): - # TODO: support multiple parallel groups. - pynccl_utils.all_reduce(input_) - else: - torch.distributed.all_reduce(input_, - group=get_tensor_model_parallel_group()) - return input_ - - -def tensor_model_parallel_all_gather(input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: - """All-gather the input tensor across model parallel group.""" - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # Allocate output tensor. - output_tensor = torch.empty((world_size, ) + input_size, - dtype=input_.dtype, - device=input_.device) - # All-gather. - torch.distributed.all_gather_into_tensor( - output_tensor, input_, group=get_tensor_model_parallel_group()) - # Reshape - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * input_size[dim], ) + - input_size[dim + 1:]) - return output_tensor - - -def tensor_model_parallel_gather(input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> torch.Tensor: - """Gather the input tensor across model parallel group. - - NOTE: We assume that the input tensor is on the same device across - all the ranks. - """ - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - # Allocate output tensor. - if get_tensor_model_parallel_rank() == dst: - gather_list = [torch.empty_like(input_) for _ in range(world_size)] - else: - gather_list = None - # Gather. - torch.distributed.gather(input_, - gather_list, - dst=dst, - group=get_tensor_model_parallel_group()) - if get_tensor_model_parallel_rank() == dst: - output_tensor = torch.cat(gather_list, dim=dim) - else: - output_tensor = None - return output_tensor - - -def broadcast(input_: torch.Tensor, - src: int = 0, - group: Optional[ProcessGroup] = None): - """Broadcast the input tensor.""" - group = group or torch.distributed.group.WORLD - ranks = torch.distributed.get_process_group_ranks(group) - assert src in ranks, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size(group=group) - if world_size == 1: - return input_ - # Broadcast. - torch.distributed.broadcast(input_, src=src, group=group) - return input_ - - -def broadcast_object_list(obj_list: List[Any], - src: int = 0, - group: Optional[ProcessGroup] = None): - """Broadcast the input object list.""" - group = group or torch.distributed.group.WORLD - ranks = torch.distributed.get_process_group_ranks(group) - assert src in ranks, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size(group=group) - if world_size == 1: - return obj_list - # Broadcast. - torch.distributed.broadcast_object_list(obj_list, src=src, group=group) - return obj_list - - -TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"]) - - -def broadcast_tensor_dict( - tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, - src: int = 0, - group: Optional[ProcessGroup] = None, -) -> Dict[Any, Union[torch.Tensor, Any]]: - """Broadcast the input tensor dictionary.""" - group = group or torch.distributed.group.WORLD - ranks = torch.distributed.get_process_group_ranks(group) - assert src in ranks, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size(group=group) - if world_size == 1: - return tensor_dict - - rank = torch.distributed.get_rank() - if rank == src: - assert isinstance( - tensor_dict, - dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - metadata_list = [] - for key, value in tensor_dict.items(): - if isinstance(value, torch.Tensor): - assert value.is_cuda, ( - f"Tensor {key}: {value} is not on cuda. Currently we only " - f"support broadcasting tensors on cuda.") - metadata_list.append( - (key, TensorMetadata(value.dtype, value.size()))) - else: - metadata_list.append((key, value)) - torch.distributed.broadcast_object_list([metadata_list], - src=src, - group=group) - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - tensor = tensor_dict[key] - torch.distributed.broadcast(tensor, src=src, group=group) - else: - recv_metadata_list = [None] - torch.distributed.broadcast_object_list(recv_metadata_list, - src=src, - group=group) - metadata_list = recv_metadata_list[0] - tensor_dict = {} - async_handles = [] - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device="cuda") - async_handle = torch.distributed.broadcast(tensor, - src=src, - async_op=True, - group=group) - async_handles.append(async_handle) - tensor_dict[key] = tensor - else: - tensor_dict[key] = value - for async_handle in async_handles: - async_handle.wait() - return tensor_dict diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py deleted file mode 100644 index bf8ee07070c8a..0000000000000 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ /dev/null @@ -1,268 +0,0 @@ -from contextlib import contextmanager -from typing import Optional - -import torch -import torch.distributed as dist - -from vllm.logger import init_logger -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) - -try: - import pynvml - - from vllm._C import custom_ar -except ImportError: - # For AMD GPUs - custom_ar = None - pynvml = None - -logger = init_logger(__name__) - -_CA_HANDLE = None -_IS_CAPTURING = False -_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] - - -def init_custom_ar() -> None: - global _CA_HANDLE - if _CA_HANDLE is not None: - return - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - if world_size == 1: - # No need to initialize custom allreduce for single GPU case. - return - - if world_size not in _SUPPORTED_WORLD_SIZES: - logger.warn( - "Custom allreduce is disabled due to an unsupported world size: " - "%d. Supported world sizes: %s. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.", world_size, - str(_SUPPORTED_WORLD_SIZES)) - return - if not _can_p2p(rank, world_size): - logger.warn( - "Custom allreduce is disabled because your platform lacks GPU P2P" - " capability or P2P test failed. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return - full_nvlink = _is_full_nvlink(rank, world_size) - if world_size > 2 and not full_nvlink: - logger.warn( - "Custom allreduce is disabled because it's not supported on more" - " than two PCIe-only GPUs. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return - _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink) - - -def begin_capture() -> None: - global _IS_CAPTURING - _IS_CAPTURING = True - - -def end_capture() -> None: - global _IS_CAPTURING - _IS_CAPTURING = False - - -def is_capturing() -> bool: - return _IS_CAPTURING and _CA_HANDLE is not None - - -def get_handle() -> Optional["CustomAllreduce"]: - return _CA_HANDLE - - -def is_initialized() -> bool: - return _CA_HANDLE is not None - - -@contextmanager -def capture(): - try: - begin_capture() - yield - finally: - end_capture() - handle = get_handle() - if handle is not None: - handle.register_graph_buffers() - - -def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: - ca_handle = get_handle() - # when custom allreduce is disabled, this will be None - if ca_handle is None: - return - if is_capturing(): - if torch.cuda.is_current_stream_capturing(): - if ca_handle.should_custom_ar(input): - return ca_handle.all_reduce_reg(input) - else: - if ca_handle.should_custom_ar(input): - # if warm up, mimic the allocation pattern - # since custom allreduce is out-of-place - return torch.empty_like(input) - else: - # note: outside of cuda graph context, - # custom allreduce incurs a cost of cudaMemcpy, which should - # be small(<=1% of overall latency) compared to the performance - # gains of using custom kernels - if ca_handle.should_custom_ar(input): - return ca_handle.all_reduce_unreg(input) - - -@contextmanager -def _nvml(): - try: - pynvml.nvmlInit() - yield - finally: - pynvml.nvmlShutdown() - - -# query if the set of gpus are fully connected by nvlink (1 hop) -@_nvml() -def _is_full_nvlink(rank, world_size): - handle = pynvml.nvmlDeviceGetHandleByIndex(rank) - for i in range(world_size): - if i != rank: - try: - link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i) - if not link_state: - return False - except pynvml.NVMLError as error: - logger.info( - f"NVLink detection failed with message \"{str(error)}\". " - "This is normal if your machine has no NVLink equipped") - return False - return True - - -def _can_p2p(rank: int, world_size: int) -> bool: - num_dev = torch.cuda.device_count() - # note: num dev can be larger than world_size if we're only using - # first few GPUs - if num_dev < world_size: - logger.warn( - "Cannot test GPU P2P because not all GPUs are visible to the " - "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" - " is set.") - return False - for i in range(world_size): - if i == rank: - continue - if not torch.cuda.can_device_access_peer(rank, i): - return False - # on some platforms, P2P support might be buggy and we need - # additional checks. See also: - # https://github.com/vllm-project/vllm/issues/2728 - if not _can_actually_p2p(rank, i): - return False - return True - - -# code partly borrowed from -# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10 -# License: MIT -def _can_actually_p2p(idx_a, idx_b): - dev_i = f"cuda:{idx_a}" - dev_j = f"cuda:{idx_b}" - a = torch.randn(5, device=dev_i) + 123.0 - b = a.to(dev_j) - c = b.to(dev_i) - return torch.all(a == c) - - -class CustomAllreduce: - - # max_size: max supported allreduce size - def __init__(self, - rank, - world_size, - full_nvlink, - max_size=8192 * 1024) -> None: - # buffers memory are owned by this Python class and passed to C++ - # meta data composes of two parts: meta data for synchronization - # (256 bytes) and a temporary buffer for storing intermediate - # allreduce results. - self.meta = torch.zeros(custom_ar.meta_size() + max_size, - dtype=torch.uint8, - device="cuda") - # This is a pre-registered IPC buffer. In eager mode, input tensors - # are first copied into this buffer before allreduce is performed - self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda") - # This is a buffer for storing the tuples of pointers pointing to - # IPC buffers from all ranks. Each registered tuple has size of - # 8*world_size bytes where world_size is at most 8. Allocating 8MB - # is enough for 131072 such tuples. The largest model I've seen only - # needs less than 10000 of registered tuples. - self.rank_data = torch.empty(8 * 1024 * 1024, - dtype=torch.uint8, - device="cuda") - self.max_size = max_size - self.world_size = world_size - handles, offsets = self._get_ipc_meta(self.meta) - self.full_nvlink = full_nvlink - self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, - handles, offsets, rank, - self.full_nvlink) - self.register_buffer(self.buffer) - - def _get_ipc_meta(self, inp: torch.Tensor): - data = inp.untyped_storage()._share_cuda_() - shard_data = ( - data[1], # ipc handle to base ptr - data[3], # offset of base ptr - ) - return self._gather_ipc_meta(shard_data) - - def _gather_ipc_meta(self, shard_data): - all_data = [None] * self.world_size - dist.all_gather_object(all_data, shard_data) - - handles = [] - offsets = [] - for i in range(len(all_data)): - handles.append(all_data[i][0]) - offsets.append(all_data[i][1]) - return handles, offsets - - def register_buffer(self, inp: torch.Tensor): - handles, offsets = self._get_ipc_meta(inp) - custom_ar.register_buffer(self._ptr, inp, handles, offsets) - - def register_graph_buffers(self): - handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr) - handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) - logger.info("Registering %d cuda graph addresses", len(offset)) - custom_ar.register_graph_buffers(self._ptr, handles, offsets) - - def should_custom_ar(self, inp: torch.Tensor): - return custom_ar.should_custom_ar(inp, self.max_size, self.world_size, - self.full_nvlink) - - # all reduce, assuming inp tensor is IPC registered with register_buffer, - # or, in the context of cuda graphs, register_graph_buffers - def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): - if out is None: - out = torch.empty_like(inp) - custom_ar.all_reduce_reg(self._ptr, inp, out) - return out - - # all reduce, assuming inp tensor is NOT IPC registered - def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): - if out is None: - out = torch.empty_like(inp) - custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) - return out - - def close(self): - if self._ptr: - custom_ar.dispose(self._ptr) - self._ptr = 0 - - def __del__(self): - self.close() diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py deleted file mode 100644 index 15bd762fd3d14..0000000000000 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ /dev/null @@ -1,281 +0,0 @@ -# This file is a pure Python wrapper for the NCCL library. -# The main purpose is to use NCCL combined with CUDA graph. -# Before writing this script, we tried the following approach: -# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself -# often gets stuck when initializing the NCCL communicator. -# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` -# contains many other potential cuda APIs, that are not allowed during -# capturing the CUDA graph. For further details, please check -# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . -# -# Another rejected idea is to write a C/C++ binding for NCCL. It is usually -# doable, but we often encounter issues related with nccl versions, and need -# to switch between different versions of NCCL. See -# https://github.com/NVIDIA/nccl/issues/1234 for more details. -# A C/C++ binding is not flexible enough to handle this. It requires -# recompilation of the code every time we want to switch between different -# versions. This current implementation, with a **pure** Python wrapper, is -# more flexible. We can easily switch between different versions of NCCL by -# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` -# variable in the code. - -import ctypes -import datetime -import glob -import os - -# ===================== import region ===================== -import torch -import torch.distributed as dist -from torch.distributed import ReduceOp - -from vllm.logger import init_logger -from vllm.utils import is_hip - -logger = init_logger(__name__) - -so_file = os.environ.get("VLLM_NCCL_SO_PATH", "") - -# check if we have vllm-managed nccl -vllm_nccl_path = None -if torch.version.cuda is not None: - cuda_major = torch.version.cuda.split(".")[0] - path = os.path.expanduser( - f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*") - files = glob.glob(path) - vllm_nccl_path = files[0] if files else None - -# manually load the nccl library -if so_file: - logger.info( - f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}") -else: - if torch.version.cuda is not None: - so_file = vllm_nccl_path or "libnccl.so.2" - elif torch.version.hip is not None: - so_file = "librccl.so.1" - else: - raise ValueError("NCCL only supports CUDA and ROCm backends.") - logger.info(f"Loading nccl from library {so_file}") - -try: - nccl = ctypes.CDLL(so_file) -except Exception as e: - logger.error( - f"Failed to load NCCL library from {so_file} ." - "It is expected if you are not running on NVIDIA/AMD GPUs." - "Otherwise please set the environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.") - raise e - -# === export types and functions from nccl to Python === -# for the original nccl definition, please check -# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in - -ncclResult_t = ctypes.c_int - -# equivalent to c declaration: -# ncclResult_t ncclGetVersion(int *version); -_c_ncclGetVersion = nccl.ncclGetVersion -_c_ncclGetVersion.restype = ctypes.c_int -_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] - - -def ncclGetVersion() -> str: - version = ctypes.c_int() - result = _c_ncclGetVersion(ctypes.byref(version)) - assert result == 0 - # something like 21903 --> "2.19.3" - version_str = str(version.value) - major = version_str[0].lstrip("0") - minor = version_str[1:3].lstrip("0") - patch = version_str[3:].lstrip("0") - return f"{major}.{minor}.{patch}" - - -class NcclUniqueId(ctypes.Structure): - _fields_ = [("internal", ctypes.c_byte * 128)] - - -# equivalent to c declaration: -# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); -_c_ncclGetUniqueId = nccl.ncclGetUniqueId -_c_ncclGetUniqueId.restype = ctypes.c_int -_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)] - - -def ncclGetUniqueId() -> NcclUniqueId: - unique_id = NcclUniqueId() - result = _c_ncclGetUniqueId(ctypes.byref(unique_id)) - assert result == 0 - return unique_id - - -# equivalent to c declaration: -# ncclResult_t ncclCommInitRank( -# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); -# note that ncclComm_t is a pointer type, so the first argument -# is a pointer to a pointer -_c_ncclCommInitRank = nccl.ncclCommInitRank -_c_ncclCommInitRank.restype = ctypes.c_int -_c_ncclCommInitRank.argtypes = [ - ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int -] - - -# enums -class ncclDataType_t(ctypes.c_int): - ncclInt8 = 0 - ncclChar = 0 - ncclUint8 = 1 - ncclInt32 = 2 - ncclInt = 2 - ncclUint32 = 3 - ncclInt64 = 4 - ncclUint64 = 5 - ncclFloat16 = 6 - ncclHalf = 6 - ncclFloat32 = 7 - ncclFloat = 7 - ncclFloat64 = 8 - ncclDouble = 8 - ncclBfloat16 = 9 - ncclNumTypes = 10 - - @classmethod - def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t': - if dtype == torch.int8: - return cls.ncclInt8 - if dtype == torch.uint8: - return cls.ncclUint8 - if dtype == torch.int32: - return cls.ncclInt32 - if dtype == torch.int64: - return cls.ncclInt64 - if dtype == torch.float16: - return cls.ncclFloat16 - if dtype == torch.float32: - return cls.ncclFloat32 - if dtype == torch.float64: - return cls.ncclFloat64 - if dtype == torch.bfloat16: - return cls.ncclBfloat16 - raise ValueError(f"Unsupported dtype: {dtype}") - - -class ncclRedOp_t(ctypes.c_int): - ncclSum = 0 - ncclProd = 1 - ncclMax = 2 - ncclMin = 3 - ncclAvg = 4 - ncclNumOps = 5 - - @classmethod - def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t': - if op == ReduceOp.SUM: - return cls.ncclSum - if op == ReduceOp.PRODUCT: - return cls.ncclProd - if op == ReduceOp.MAX: - return cls.ncclMax - if op == ReduceOp.MIN: - return cls.ncclMin - if op == ReduceOp.AVG: - return cls.ncclAvg - raise ValueError(f"Unsupported op: {op}") - - -# equivalent to c declaration: -# ncclResult_t ncclAllReduce( -# const void* sendbuff, void* recvbuff, size_t count, -# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, -# udaStream_t stream); -# note that cudaStream_t is a pointer type, so the last argument is a pointer -_c_ncclAllReduce = nccl.ncclAllReduce -_c_ncclAllReduce.restype = ctypes.c_int -_c_ncclAllReduce.argtypes = [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p -] - -# equivalent to c declaration: -# ncclResult_t ncclCommDestroy(ncclComm_t comm); -_c_ncclCommDestroy = nccl.ncclCommDestroy -_c_ncclCommDestroy.restype = ctypes.c_int -_c_ncclCommDestroy.argtypes = [ctypes.c_void_p] - - -class NCCLCommunicator: - - def __init__( - self, - backend=None, - init_method=None, - timeout=datetime.timedelta(seconds=10), - world_size: int = -1, - rank: int = -1, - store=None, - group_name: str = "", - pg_options=None, - local_rank: int = -1, - ): - assert not is_hip() # We do not use pynccl on ROCm - if not dist.is_initialized(): - backend = backend or "nccl" - assert backend == 'nccl', ( - "only use nccl backend for starting the NCCL communicator") - dist.init_process_group(backend=backend, - init_method=init_method, - timeout=timeout, - world_size=world_size, - rank=rank, - store=store, - group_name=group_name, - pg_options=pg_options) - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - if local_rank == -1: - local_rank = self.rank - self.local_rank = local_rank - # don't use these args, as they can be -1 - # use `self.rank`, `self.local_rank` and `self.world_size` instead - del world_size, rank, local_rank - torch.cuda.set_device(self.local_rank) - if self.rank == 0: - self.unique_id = ncclGetUniqueId() - else: - self.unique_id = NcclUniqueId() - tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda( - self.local_rank) - dist.broadcast(tensor, src=0) - byte_list = tensor.cpu().tolist() - for i, byte in enumerate(byte_list): - self.unique_id.internal[i] = byte - self.comm = ctypes.c_void_p() - result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, - self.unique_id, self.rank) - assert result == 0 - self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}") - - def all_reduce(self, - tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None): - if stream is None: - stream = self.stream - result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), - ctypes.c_void_p(tensor.data_ptr()), - tensor.numel(), - ncclDataType_t.from_torch(tensor.dtype), - ncclRedOp_t.from_torch(op), self.comm, - ctypes.c_void_p(stream.cuda_stream)) - assert result == 0 - - def __del__(self): - # `dist` module might have been already destroyed - if hasattr(dist, 'destroy_process_group'): - dist.destroy_process_group() - # function might have been already destroyed - if _c_ncclCommDestroy is not None: - _c_ncclCommDestroy(self.comm) diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/model_executor/parallel_utils/pynccl_utils.py deleted file mode 100644 index a099777aa0005..0000000000000 --- a/vllm/model_executor/parallel_utils/pynccl_utils.py +++ /dev/null @@ -1,69 +0,0 @@ -import contextlib -from typing import Optional - -import torch -from torch.distributed import ReduceOp - -from vllm.logger import init_logger - -logger = init_logger(__name__) - -try: - from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, - ncclGetVersion) -except Exception as e: - # in non-NVIDIA environments, we can't import the nccl module - # e.g. when running on machines with AMD GPUs - logger.info(f"Failed to import NCCL library: {e}") - logger.info("It is expected if you are not running on NVIDIA GPUs.") - pass - -comm: Optional["NCCLCommunicator"] = None - - -def is_initialized() -> bool: - """Returns whether the NCCL backend is initialized.""" - return comm is not None - - -@contextlib.contextmanager -def set_pynccl_stream(stream: torch.cuda.Stream): - """Set the cuda stream for communication""" - try: - comm.stream = stream - yield - finally: - pass - - -def init_process_group(world_size: int, - rank: int, - init_method: str, - local_rank: int = -1) -> None: - assert not is_initialized() - global comm - logger.info(f"vLLM is using nccl=={ncclGetVersion()}") - comm = NCCLCommunicator(init_method=init_method, - world_size=world_size, - local_rank=local_rank, - rank=rank) - - -def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: - """All-reduces the input tensor across the process group.""" - assert input_.is_cuda, f"{input_} should be a cuda tensor" - comm.all_reduce(input_, op) - - -def destroy_process_group() -> None: - global comm - comm = None - - -def get_world_size() -> int: - """Returns the world size.""" - return comm.world_size - - -def get_nccl_backend(): - return comm diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py new file mode 100644 index 0000000000000..b86cafce85d12 --- /dev/null +++ b/vllm/model_executor/pooling_metadata.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import torch + +from vllm.pooling_params import PoolingParams +from vllm.utils import is_pin_memory_available + + +class PoolingMetadata: + """Metadata for pooling operations in the Pooler layer. + + This class holds the necessary information for pooling operations, + providing context for how to perform pooling and other related operations. + + Attributes: + seq_groups: List of (seq_ids, pooling_params). + seq_data: A mapping of sequence ID to additional sequence data. + prompt_lens: List of the lengths of each prompt. + """ + + def __init__( + self, + seq_groups: List[Tuple[List[int], PoolingParams]], + seq_data: Dict[int, Any], # Specific data related to sequences + prompt_lens: List[int], + ) -> None: + self.seq_groups = seq_groups + self.seq_data = seq_data + self.prompt_lens = prompt_lens + + def __repr__(self) -> str: + return ("PoolingMetadata(" + f"seq_groups={self.seq_groups}, " + f"seq_data={self.seq_data}, " + f"prompt_lens={self.prompt_lens})") + + +@dataclass +class PoolingTensors: + """Tensors for pooling.""" + + prompt_lens: torch.Tensor + + @classmethod + def from_pooling_metadata( + cls, + pooling_metadata: "PoolingMetadata", + device: torch.device, + ) -> "PoolingTensors": + """ + Create PoolingTensors from PoolingMetadata. + + Args: + pooling_metadata: PoolingMetadata instance to convert. + device: Device to store the tensors. + """ + # Convert prompt lengths to tensor + pin_memory = is_pin_memory_available() + + prompt_lens_t = torch.tensor( + pooling_metadata.prompt_lens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + + return cls(prompt_lens=prompt_lens_t.to(device=device, + non_blocking=True), ) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 534cb75c2fd2f..9969c45963e9a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -6,57 +6,284 @@ from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SequenceData -from vllm.utils import is_pin_memory_available +from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.utils import (async_tensor_h2d, is_pin_memory_available, + maybe_expand_dim) _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 +@dataclass +class SequenceGroupToSample: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Sequence ids for the sequence group in a previous step. + seq_ids: List[int] + sampling_params: SamplingParams + # seq_id -> sequence data. + seq_data: Dict[int, SequenceData] + # The length of the sequence (all tokens seen in the past + new token to + # compute attention) of the sequence group. None if it is in a decode + # stage. + seq_len: Optional[int] + # The length of new query tokens to compute in the current step. None if it + # is in a decode stage. The length of query_len <= seq_len if chunked + # prefill is enabled. + query_len: Optional[int] + # A random number generator for sampling. + generator: Optional[torch.Generator] + # True if the sequence group is in prefill stage. False if it is in a + # decode stage. + is_prompt: bool + # Query token indices from logits. to compute prompt logprob. Empty if + # prompt logprob is not required. + prompt_logprob_indices: List[int] + # Sample token indices from logits. Empty if sampling is not required. + sample_indices: List[int] + + @property + def do_sample(self): + return len(self.sample_indices) > 0 + + def __post_init__(self): + if len(self.prompt_logprob_indices) > 0: + assert self.sampling_params.prompt_logprobs is not None + if self.is_prompt: + assert self.seq_len is not None + assert self.query_len is not None + + class SamplingMetadata: """Metadata for input sequences. Used in sampler. + The usage is as follow; + ``` + hidden_states = execute_model(...) + logits = hidden_states[sampling_metadata.selected_token_indices] + sample(logits) + + def sample(logits): + # Use categorized_sample_indices for sampling.... + ``` + Args: - seq_groups: List of (seq_ids, sampling_params). - seq_data: Seq_id -> SequenceData. - prompt_lens: Lengths of prompts. - selected_token_indices: Token indices selected for sampling. + seq_groups: List of batched sequence groups. + selected_token_indices: (num_query_tokens_to_logprob). Indices to find + logits from the initial model output hidden states. categorized_sample_indices: SamplingType -> token indices to sample. - generators: List of torch.Generators to use for seeded sampling - perform_sampling: Whether to perform sampling. This option is used to - make the sampling only happens in the driver worker, and disable - sampling in other worker processes. + Each token indices is 2D tensor of (num_indices, num_indices) where + the first item means the sample index within the returned logit + (before pruning padding), and the second item means the sample + index after pruning using selected_token_indices. + For example, if the returned logit is [1, 2, 3], and we select + [1, 2] for sampling, the pruned logit will be [2, 3]. In this case, + The first tuple is [1, 2] (sampled index within original logit), + and the second tuple is [0, 1] (sampled index within pruned logit). + num_prompts: Number of prompt sequence groups in seq_groups. """ def __init__( self, - seq_groups: Optional[List[Tuple[List[int], SamplingParams]]], - seq_data: Optional[Dict[int, SequenceData]], - prompt_lens: Optional[List[int]], + seq_groups: List[SequenceGroupToSample], selected_token_indices: torch.Tensor, - categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]], - generators: Optional[List[torch.Generator]] = None, - perform_sampling: bool = True, + categorized_sample_indices: Dict[SamplingType, torch.Tensor], + num_prompts: int, ) -> None: self.seq_groups = seq_groups - self.seq_data = seq_data - self.prompt_lens = prompt_lens self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices - self.generators = generators - self.perform_sampling = perform_sampling + self.num_prompts = num_prompts - self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0 + @staticmethod + def prepare( + seq_group_metadata_list: List[SequenceGroupMetadata], + seq_lens: List[int], + query_lens: Optional[List[int]], + device: str, + pin_memory: bool, + ) -> "SamplingMetadata": + ( + seq_groups, + selected_token_indices, + categorized_sample_indices, + num_prompts, + ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, + device) + selected_token_indices = async_tensor_h2d(selected_token_indices, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory) + categorized_sample_indices = { + t: maybe_expand_dim( + async_tensor_h2d(seq_ids, + dtype=torch.int, + target_device=device, + pin_memory=pin_memory), 2, 2) + for t, seq_ids in categorized_sample_indices.items() + } + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + num_prompts=num_prompts, + ) + return sampling_metadata def __repr__(self) -> str: return ( "SamplingMetadata(" f"seq_groups={self.seq_groups}, " - f"seq_data={self.seq_data}, " - f"prompt_lens={self.prompt_lens}, " f"selected_token_indices={self.selected_token_indices}, " - f"categorized_sample_indices={self.categorized_sample_indices}), " - f"perform_sampling={self.perform_sampling})") + f"categorized_sample_indices={self.categorized_sample_indices}), ") + + +def _prepare_seq_groups( + seq_group_metadata_list: List[SequenceGroupMetadata], + seq_lens: List[int], + query_lens: Optional[List[int]], + device: str, +) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ + SamplingType, List[Tuple[int, int]]], int]: + """Prepare sequence groups and indices for sampling. + + Args: + seq_group_metadata_list: A list of sequence group to batch. + seq_lens: A list of sequence lens per sequence group. + Index of prompt len should match with seq_group_metadata_list. + query_lens: A list of query lengths. Prompt lens include the length + of entire prompt tokens, and it could be shorter. + device: A device to use for random number generator, + `SequenceGroupToSample.generator`. + + Returns: + seq_groups: A list of sequence group to sample. + selected_token_indices: See the definition from `SamplingMetadata`. + categorized_sample_indices: See the definition from `SamplingMetadata`. + num_prompts: Total number of prompts from `seq_group_metadata_list`. + """ + # Batched sequence groups for the current model forward stsep. + seq_groups: List[SequenceGroupToSample] = [] + # A list of token indices to sample/compute logprob. It is used to + # prune the outcome logits from the model for the performance. + selected_token_indices: List[int] = [] + # Used for selected_token_indices. + model_output_idx = 0 + + # Sampling type -> ( + # indices to sample/prompt logprob within pruned output logits, + # indices to sample within pruned logits) + categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = { + t: [] + for t in SamplingType + } + # Index of logits to compute logprob. Logits include both prompt logprob + # and sample logprob indices. + logit_idx = 0 + # Index to sample from a sample tensor. It is used by triton sample kernel. + # See `_sample_with_triton_kernel` for more details. + sample_idx = 0 + # Total number of prompts from given sequence groups. + num_prompts = 0 + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + is_prompt = seq_group_metadata.is_prompt + generator: Optional[torch.Generator] = None + # If the current seq group is in decode stage, it is None. + seq_len: Optional[int] = None + query_len: Optional[int] = None + prompt_logprob_indices: List[int] = [] + sample_indices: List[int] = [] + do_sample = seq_group_metadata.do_sample + + if seq_group_metadata.is_prompt: + if sampling_params.seed is not None: + seq_group_metadata.state.generator = torch.Generator( + device=device).manual_seed(sampling_params.seed) + + num_prompts += 1 + num_prefill_sample = len(seq_ids) + assert num_prefill_sample == 1 + assert query_lens is not None and seq_lens is not None + query_len, seq_len = query_lens[i], seq_lens[i] + # If we need sampling, exclude num_prefill_sample tokens from + # prompt logprob. + prompt_logprob_len = (query_len - num_prefill_sample + if do_sample else query_len) + sample_len = num_prefill_sample if do_sample else 0 + else: + # Decode + prompt_logprob_len = 0 + sample_len = len(seq_ids) if do_sample else 0 + + # Update indices to select from the model output. + """ + This blocks computes selected_token_indices which is used in the + following way. + + hidden_states = model(...) + logits = hidden_states[selected_token_indices] + """ + + if sampling_params.prompt_logprobs: + selected_token_indices.extend( + range(model_output_idx, model_output_idx + prompt_logprob_len)) + model_output_idx += prompt_logprob_len + if do_sample: + selected_token_indices.extend( + range(model_output_idx, model_output_idx + sample_len)) + model_output_idx += sample_len + + # We now find indices for logprob computation and sampling. + """ + This block computes categorized_sample_indices which is used in the + following way. + + hidden_states = model(...) + logits = hidden_states[selected_token_indices] + def sample(logits): + # Use categorized_sample_indices for sampling. + # prompt_logprob_indices to find prompt logprob indices. + # sample_indices to find sample indices. + """ + + if sampling_params.prompt_logprobs is not None: + prompt_logprob_indices.extend( + range(logit_idx, logit_idx + prompt_logprob_len)) + logit_idx += prompt_logprob_len + if do_sample: + sample_indices.extend(range(logit_idx, logit_idx + sample_len)) + categorized_sample_indices[sampling_params.sampling_type].extend( + list( + zip(range(logit_idx, logit_idx + sample_len), + range(sample_idx, sample_idx + sample_len)))) + logit_idx += sample_len + sample_idx += sample_len + + if sampling_params.seed is not None: + generator = seq_group_metadata.state.generator + + seq_groups.append( + SequenceGroupToSample( + seq_ids=seq_ids, + sampling_params=sampling_params, + seq_data=seq_group_metadata.seq_data, + seq_len=seq_len, + query_len=query_len, + generator=generator, + is_prompt=is_prompt, + prompt_logprob_indices=list(prompt_logprob_indices), + sample_indices=list(sample_indices))) + return (seq_groups, selected_token_indices, categorized_sample_indices, + num_prompts) @dataclass @@ -112,9 +339,10 @@ def from_sampling_metadata( seeds_to_generate = (extra_seeds_to_generate + get_num_triton_sampler_splits(vocab_size)) - sample_indices_start_idx = 0 - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group + assert sampling_metadata.seq_groups is not None + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params temperature = sampling_params.temperature p = sampling_params.presence_penalty f = sampling_params.frequency_penalty @@ -143,43 +371,46 @@ def from_sampling_metadata( or abs(r - 1.0) >= _SAMPLING_EPS): do_penalties = True - if (i < sampling_metadata.num_prompts + is_prompt = seq_group.is_prompt + if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs - prompt_len = sampling_metadata.prompt_lens[i] - temperatures += [temperature] * (prompt_len - 1) - top_ps += [top_p] * (prompt_len - 1) - top_ks += [top_k] * (prompt_len - 1) - min_ps += [min_p] * (prompt_len - 1) - presence_penalties += [0] * (prompt_len - 1) - frequency_penalties += [0] * (prompt_len - 1) - repetition_penalties += [1] * (prompt_len - 1) - prompt_tokens.extend([] for _ in range(prompt_len - 1)) - output_tokens.extend([] for _ in range(prompt_len - 1)) - for seq_id in seq_ids: - seq_data = sampling_metadata.seq_data[seq_id] - prompt_tokens.append(seq_data.prompt_token_ids) - output_tokens.append(seq_data.output_token_ids) - temperatures += [temperature] * len(seq_ids) - top_ps += [top_p] * len(seq_ids) - top_ks += [top_k] * len(seq_ids) - min_ps += [min_p] * len(seq_ids) - presence_penalties += [p] * len(seq_ids) - frequency_penalties += [f] * len(seq_ids) - repetition_penalties += [r] * len(seq_ids) - - is_prompt = i < sampling_metadata.num_prompts + query_len = seq_group.query_len + assert query_len is not None + prefill_len = len(seq_group.prompt_logprob_indices) + temperatures += [temperature] * prefill_len + top_ps += [top_p] * prefill_len + top_ks += [top_k] * prefill_len + min_ps += [min_p] * prefill_len + presence_penalties += [0] * prefill_len + frequency_penalties += [0] * prefill_len + repetition_penalties += [1] * prefill_len + prompt_tokens.extend([] for _ in range(prefill_len)) + output_tokens.extend([] for _ in range(prefill_len)) + + if seq_group.do_sample: + sample_lens = len(seq_group.sample_indices) + assert sample_lens == len(seq_ids) + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids) + output_tokens.append(seq_data.output_token_ids) + temperatures += [temperature] * len(seq_ids) + top_ps += [top_p] * len(seq_ids) + top_ks += [top_k] * len(seq_ids) + min_ps += [min_p] * len(seq_ids) + presence_penalties += [p] * len(seq_ids) + frequency_penalties += [f] * len(seq_ids) + repetition_penalties += [r] * len(seq_ids) + if is_prompt: prompt_best_of.append(sampling_params.best_of) - prompt_len = sampling_metadata.prompt_lens[i] + query_len = seq_group.query_len + assert query_len is not None - if sampling_params.prompt_logprobs is not None: - # NOTE: the sampling position is the last token - # in the prompt - sample_indices_start_idx += prompt_len - 1 for seq_id in seq_ids: - seq_data = sampling_metadata.seq_data[seq_id] + seq_data = seq_group.seq_data[seq_id] extra_entropy = extra_entropy or () seq_seeds = cls._get_sequence_seeds( seed, @@ -189,8 +420,7 @@ def from_sampling_metadata( seeds_to_generate=seeds_to_generate, is_greedy=is_greedy) sampling_seeds.append(seq_seeds) - sample_indices.append(sample_indices_start_idx) - sample_indices_start_idx += 1 + sample_indices.extend(seq_group.sample_indices) sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, @@ -213,12 +443,14 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], # Note that the performance will be very bad without # pinned memory. pin_memory = is_pin_memory_available() - prompt_max_len = max(len(tokens) for tokens in prompt_tokens) + prompt_max_len = max([len(tokens) for tokens in prompt_tokens], + default=0) prompt_padded_tokens = [ tokens + [vocab_size] * (prompt_max_len - len(tokens)) for tokens in prompt_tokens ] - output_max_len = max(len(tokens) for tokens in output_tokens) + output_max_len = max([len(tokens) for tokens in output_tokens], + default=0) output_padded_tokens = [ tokens + [vocab_size] * (output_max_len - len(tokens)) for tokens in output_tokens diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py deleted file mode 100644 index 8016941216994..0000000000000 --- a/vllm/model_executor/weight_utils.py +++ /dev/null @@ -1,375 +0,0 @@ -"""Utilities for downloading and initializing model weights.""" -import fnmatch -import glob -import hashlib -import json -import os -from collections import defaultdict -from typing import Any, Iterable, Iterator, List, Optional, Tuple - -import filelock -import huggingface_hub.constants -import numpy as np -import torch -from huggingface_hub import HfFileSystem, snapshot_download -from safetensors.torch import load_file, safe_open, save_file -from tqdm.auto import tqdm - -from vllm.config import ModelConfig -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import (QuantizationConfig, - get_quantization_config) -from vllm.model_executor.layers.quantization.schema import QuantParamSchema - -logger = init_logger(__name__) - -# use system-level temp directory for file locks, so that multiple users -# can share the same lock without error. -# lock files in the temp directory will be automatically deleted when the -# system reboots, so users will not complain about annoying lock files -temp_dir = os.environ.get('TMPDIR') or os.environ.get( - 'TEMP') or os.environ.get('TMP') or "/tmp/" - - -def enable_hf_transfer(): - """automatically activates hf_transfer - """ - if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: - try: - # enable hf hub transfer if available - import hf_transfer # type: ignore # noqa - huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True - except ImportError: - pass - - -enable_hf_transfer() - - -class Disabledtqdm(tqdm): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, disable=True) - - -def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): - lock_dir = cache_dir or temp_dir - os.makedirs(os.path.dirname(lock_dir), exist_ok=True) - model_name = model_name_or_path.replace("/", "-") - hash_name = hashlib.sha256(model_name.encode()).hexdigest() - # add hash to avoid conflict with old users' lock files - lock_file_name = hash_name + model_name + ".lock" - # mode 0o666 is required for the filelock to be shared across users - lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), - mode=0o666) - return lock - - -def _shared_pointers(tensors): - ptrs = defaultdict(list) - for k, v in tensors.items(): - ptrs[v.data_ptr()].append(k) - failing = [] - for _, names in ptrs.items(): - if len(names) > 1: - failing.append(names) - return failing - - -def convert_bin_to_safetensor_file( - pt_filename: str, - sf_filename: str, -) -> None: - loaded = torch.load(pt_filename, map_location="cpu") - if "state_dict" in loaded: - loaded = loaded["state_dict"] - shared = _shared_pointers(loaded) - for shared_weights in shared: - for name in shared_weights[1:]: - loaded.pop(name) - - # For tensors to be contiguous - loaded = {k: v.contiguous() for k, v in loaded.items()} - - dirname = os.path.dirname(sf_filename) - os.makedirs(dirname, exist_ok=True) - save_file(loaded, sf_filename, metadata={"format": "pt"}) - - # check file size - sf_size = os.stat(sf_filename).st_size - pt_size = os.stat(pt_filename).st_size - if (sf_size - pt_size) / pt_size > 0.01: - raise RuntimeError(f"""The file size different is more than 1%: - - {sf_filename}: {sf_size} - - {pt_filename}: {pt_size} - """) - - # check if the tensors are the same - reloaded = load_file(sf_filename) - for k in loaded: - pt_tensor = loaded[k] - sf_tensor = reloaded[k] - if not torch.equal(pt_tensor, sf_tensor): - raise RuntimeError(f"The output tensors do not match for key {k}") - - -# TODO(woosuk): Move this to other place. -def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: - quant_cls = get_quantization_config(model_config.quantization) - # Read the quantization config from the HF model config, if available. - hf_quant_config = getattr(model_config.hf_config, "quantization_config", - None) - if hf_quant_config is not None: - return quant_cls.from_config(hf_quant_config) - model_name_or_path = model_config.model - is_local = os.path.isdir(model_name_or_path) - if not is_local: - # Download the config files. - with get_lock(model_name_or_path, model_config.download_dir): - hf_folder = snapshot_download(model_name_or_path, - revision=model_config.revision, - allow_patterns="*.json", - cache_dir=model_config.download_dir, - tqdm_class=Disabledtqdm) - else: - hf_folder = model_name_or_path - config_files = glob.glob(os.path.join(hf_folder, "*.json")) - - quant_config_files = [ - f for f in config_files if any( - f.endswith(x) for x in quant_cls.get_config_filenames()) - ] - if len(quant_config_files) == 0: - raise ValueError( - f"Cannot find the config file for {model_config.quantization}") - if len(quant_config_files) > 1: - raise ValueError( - f"Found multiple config files for {model_config.quantization}: " - f"{quant_config_files}") - - quant_config_file = quant_config_files[0] - with open(quant_config_file, "r") as f: - config = json.load(f) - return quant_cls.from_config(config) - - -def prepare_hf_model_weights( - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - fall_back_to_pt: bool = True, - revision: Optional[str] = None, -) -> Tuple[str, List[str], bool]: - # Download model weights from huggingface. - is_local = os.path.isdir(model_name_or_path) - use_safetensors = False - # Some quantized models use .pt files for storing the weights. - if load_format == "auto": - allow_patterns = ["*.safetensors", "*.bin"] - elif load_format == "safetensors": - use_safetensors = True - allow_patterns = ["*.safetensors"] - if os.path.isfile(model_name_or_path): - return (os.path.dirname(model_name_or_path), - [model_name_or_path], - True) - elif load_format == "pt": - allow_patterns = ["*.pt"] - elif load_format == "npcache": - allow_patterns = ["*.bin"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - if not is_local: - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] - break - - logger.info(f"Using model weights format {allow_patterns}") - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model_name_or_path, cache_dir): - hf_folder = snapshot_download(model_name_or_path, - allow_patterns=allow_patterns, - cache_dir=cache_dir, - tqdm_class=Disabledtqdm, - revision=revision) - else: - hf_folder = model_name_or_path - hf_weights_files: List[str] = [] - for pattern in allow_patterns: - hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True - break - if not use_safetensors: - # Exclude files that are not needed for inference. - # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 - blacklist = [ - "training_args.bin", - "optimizer.bin", - "optimizer.pt", - "scheduler.pt", - "scaler.pt", - ] - hf_weights_files = [ - f for f in hf_weights_files - if not any(f.endswith(x) for x in blacklist) - ] - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") - - return hf_folder, hf_weights_files, use_safetensors - - -def hf_model_weights_iterator( - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - fall_back_to_pt: Optional[bool] = True, -) -> Iterator[Tuple[str, torch.Tensor]]: - hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights( - model_name_or_path, - cache_dir=cache_dir, - load_format=load_format, - fall_back_to_pt=fall_back_to_pt, - revision=revision) - - if load_format == "npcache": - # Currently np_cache only support *.bin checkpoints - assert use_safetensors is False - - # Convert the model weights from torch tensors to numpy arrays for - # faster loading. - np_folder = os.path.join(hf_folder, "np") - os.makedirs(np_folder, exist_ok=True) - weight_names_file = os.path.join(np_folder, "weight_names.json") - # Use file lock to prevent multiple processes from - # dumping the same model weights to numpy at the same time. - with get_lock(model_name_or_path, cache_dir): - if not os.path.exists(weight_names_file): - weight_names = [] - for bin_file in hf_weights_files: - state = torch.load(bin_file, map_location="cpu") - for name, param in state.items(): - param_path = os.path.join(np_folder, name) - with open(param_path, "wb") as f: - np.save(f, param.cpu().detach().numpy()) - weight_names.append(name) - with open(weight_names_file, "w") as f: - json.dump(weight_names, f) - - with open(weight_names_file, "r") as f: - weight_names = json.load(f) - - for name in weight_names: - param_path = os.path.join(np_folder, name) - with open(param_path, "rb") as f: - param = np.load(f) - yield name, torch.from_numpy(param) - elif use_safetensors: - for st_file in hf_weights_files: - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param - else: - for bin_file in hf_weights_files: - state = torch.load(bin_file, map_location="cpu") - for name, param in state.items(): - yield name, param - del state - torch.cuda.empty_cache() - - -def kv_cache_scales_loader( - filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, - model_type: Optional[str]) -> Iterable[Tuple[int, float]]: - """ - A simple utility to read in KV cache scaling factors that have been - previously serialized to disk. Used by the model to populate the appropriate - KV cache scaling factors. The serialization should represent a dictionary - whose keys are the TP ranks and values are another dictionary mapping layers - to their KV cache scaling factors. - Keep this function in sync with the output of examples/fp8/extract_scales.py - """ - try: - with open(filename) as f: - context = { - "model_type": model_type, - "num_hidden_layers": num_hidden_layers, - "tp_rank": tp_rank, - "tp_size": tp_size, - } - schema_dct = json.load(f) - schema = QuantParamSchema.model_validate(schema_dct, - context=context) - layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] - return layer_scales_map.items() - - except FileNotFoundError: - logger.error(f"File or directory '{filename}' not found.") - except json.JSONDecodeError: - logger.error(f"Error decoding JSON in file '{filename}'.") - except Exception as e: - logger.error(f"An error occurred while reading '{filename}': {e}") - # This section is reached if and only if any of the excepts are hit - # Return an empty iterable (list) => no KV cache scales are loaded - # which ultimately defaults to 1.0 scales - logger.warning("Defaulting to KV cache scaling factors = 1.0 " - f"for all layers in TP rank {tp_rank} " - "as an error occurred during loading.") - return [] - - -def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: - """convert PySafeSlice object from safetensors to torch.Tensor - - PySafeSlice object supports indexing, which is done before loading the - actual tensor and can reduce the amount of memory being read into the - memory. However, it does not support more advanced functionalities - like `.view()` or `.t()`. Therefore, if we need to modify the loaded - tensor with these more complicated operators, we need to convert to - tensor first. - """ - if not isinstance(x, torch.Tensor): - x = x[:] - return x - - -def default_weight_loader(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: - """Default weight loader.""" - assert param.size() == loaded_weight.size() - param.data.copy_(loaded_weight) - - -def initialize_dummy_weights( - model: torch.nn.Module, - low: float = -1e-3, - high: float = 1e-3, -) -> None: - """Initialize model weights with random values. - - The model weights must be randomly initialized for accurate performance - measurements. Additionally, the model weights should not cause NaNs in the - forward pass. We empirically found that initializing the weights with - values between -1e-3 and 1e-3 works well for most models. - """ - for param in model.state_dict().values(): - if torch.is_floating_point(param): - param.data.uniform_(low, high) diff --git a/vllm/outputs.py b/vllm/outputs.py index 61fe20bfc2744..49f526b5f9300 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,4 +1,5 @@ import time +from dataclasses import dataclass from typing import List, Optional, Union from vllm.lora.request import LoRARequest @@ -6,6 +7,7 @@ SequenceGroup, SequenceStatus) +@dataclass class CompletionOutput: """The output data of one completion output of a request. @@ -24,25 +26,14 @@ class CompletionOutput: lora_request: The LoRA request that was used to generate the output. """ - def __init__( - self, - index: int, - text: str, - token_ids: List[int], - cumulative_logprob: float, - logprobs: Optional[SampleLogprobs], - finish_reason: Optional[str] = None, - stop_reason: Union[int, str, None] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: - self.index = index - self.text = text - self.token_ids = token_ids - self.cumulative_logprob = cumulative_logprob - self.logprobs = logprobs - self.finish_reason = finish_reason - self.stop_reason = stop_reason - self.lora_request = lora_request + index: int + text: str + token_ids: List[int] + cumulative_logprob: float + logprobs: Optional[SampleLogprobs] + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + lora_request: Optional[LoRARequest] = None def finished(self) -> bool: return self.finish_reason is not None @@ -57,8 +48,24 @@ def __repr__(self) -> str: f"stop_reason={self.stop_reason})") +@dataclass +class EmbeddingOutput: + """The output data of one completion output of a request. + + Args: + embedding: The embedding vector, which is a list of floats. The + length of vector depends on the model as listed in the embedding guide. + """ + + embedding: List[float] + + def __repr__(self) -> str: + return (f"EmbeddingOutput(" + f"embedding={len(self.embedding)})") + + class RequestOutput: - """The output data of a request to the LLM. + """The output data of a completion request to the LLM. Args: request_id: The unique ID of the request. @@ -74,7 +81,7 @@ class RequestOutput: def __init__( self, request_id: str, - prompt: str, + prompt: Optional[str], prompt_token_ids: List[int], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], @@ -93,6 +100,9 @@ def __init__( @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": + if seq_group.sampling_params is None: + raise ValueError( + "Sampling parameters are missing for a CompletionRequest.") seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs @@ -112,8 +122,10 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # always has the logprobs of the sampled tokens even if the # logprobs are not requested. include_logprobs = seq_group.sampling_params.logprobs is not None + text_buffer_length = seq_group.sampling_params.output_text_buffer_length outputs = [ - CompletionOutput(seqs.index(seq), seq.output_text, + CompletionOutput(seqs.index(seq), + seq.get_output_text_to_return(text_buffer_length), seq.get_output_token_ids(), seq.get_cumulative_logprob(), seq.output_logprobs if include_logprobs else None, @@ -146,3 +158,61 @@ def __repr__(self) -> str: f"finished={self.finished}, " f"metrics={self.metrics}, " f"lora_request={self.lora_request})") + + +class EmbeddingRequestOutput: + """ + The output data of an embedding request to the LLM. + + Args: + request_id (str): A unique identifier for the embedding request. + outputs (EmbeddingOutput): The embedding results for the given input. + prompt_token_ids (List[int]): A list of token IDs used in the prompt. + finished (bool): A flag indicating whether the embedding is completed. + """ + + def __init__(self, request_id: str, outputs: "EmbeddingOutput", + prompt_token_ids: List[int], finished: bool): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids + self.finished = finished + self.outputs = outputs + + @classmethod + def from_seq_group(cls, + seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput": + if seq_group.embeddings is None: + raise ValueError( + "Embeddings are missing in seq_group for EmbeddingRequest.") + output = EmbeddingOutput(seq_group.embeddings) + prompt_token_ids = seq_group.prompt_token_ids + finished = seq_group.is_finished() + + return cls(seq_group.request_id, output, prompt_token_ids, finished) + + def __repr__(self): + """ + Returns a string representation of an EmbeddingRequestOutput instance. + + The representation includes the request_id and the number of outputs, + providing a quick overview of the embedding request's results. + + Returns: + str: A string representation of the EmbeddingRequestOutput instance. + """ + return (f"EmbeddingRequestOutput(request_id='{self.request_id}', " + f"outputs={repr(self.outputs)}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"finished={self.finished})") + + +class RequestOutputFactory: + + @staticmethod + def create(seq_group): + # Determine the type based on a condition, for example: + if hasattr(seq_group, + 'embeddings') and seq_group.embeddings is not None: + return EmbeddingRequestOutput.from_seq_group(seq_group) + else: + return RequestOutput.from_seq_group(seq_group) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py new file mode 100644 index 0000000000000..3b95d73ddc2c5 --- /dev/null +++ b/vllm/pooling_params.py @@ -0,0 +1,20 @@ +from typing import Any, Optional + + +class PoolingParams: + """Pooling parameters for pooling. + + Attributes: + additional_data: Any additional data needed for pooling. + """ + + def __init__(self, additional_data: Optional[Any] = None): + self.additional_data = additional_data + + def clone(self) -> "PoolingParams": + """Returns a deep copy of the PoolingParams instance.""" + return PoolingParams(additional_data=self.additional_data, ) + + def __repr__(self) -> str: + return (f"PoolingParams(" + f"additional_metadata={self.additional_data})") diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4fdc3c6dedaef..9d8a361353e26 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -2,10 +2,11 @@ import copy from enum import IntEnum from functools import cached_property -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch -from pydantic import conint +from pydantic import Field +from typing_extensions import Annotated _SAMPLING_EPS = 1e-5 @@ -17,10 +18,14 @@ class SamplingType(IntEnum): BEAM = 3 -LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] -"""LogitsProcessor is a function that takes a list of previously generated -tokens and a tensor of the logits for the next token, and returns a modified -tensor of logits to sample from.""" +LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], + Callable[[List[int], List[int], torch.Tensor], + torch.Tensor]] +"""LogitsProcessor is a function that takes a list +of previously generated tokens, the logits tensor +for the next token and, optionally, prompt tokens as a +first argument, and returns a modified tensor of logits +to sample from.""" class SamplingParams: @@ -94,7 +99,8 @@ class SamplingParams: spaces_between_special_tokens: Whether to add spaces between special tokens in the output. Defaults to True. logits_processors: List of functions that modify logits based on - previously generated tokens. + previously generated tokens, and optionally prompt tokens as + a first argument. truncate_prompt_tokens: If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). Defaults to None (i.e., no truncation). @@ -127,7 +133,7 @@ def __init__( skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, logits_processors: Optional[List[LogitsProcessor]] = None, - truncate_prompt_tokens: Optional[conint(ge=1)] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -138,7 +144,10 @@ def __init__( self.top_p = top_p self.top_k = top_k self.min_p = min_p - self.seed = seed + if seed == -1: + self.seed = None + else: + self.seed = seed self.use_beam_search = use_beam_search self.length_penalty = length_penalty self.early_stopping = early_stopping @@ -166,6 +175,13 @@ def __init__( self.logits_processors = logits_processors self.include_stop_str_in_output = include_stop_str_in_output self.truncate_prompt_tokens = truncate_prompt_tokens + # Number of characters to hold back for stop string evaluation + # until sequence is finished. + if self.stop and not include_stop_str_in_output: + self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 + else: + self.output_text_buffer_length = 0 + self._verify_args() if self.use_beam_search: self._verify_beam_search() @@ -177,8 +193,8 @@ def __init__( self.top_k = -1 self.min_p = 0.0 self._verify_greedy_sampling() - # injected by the engine - self.eos_token_id = None + # eos_token_id is added to this by the engine + self.all_stop_token_ids = set(self.stop_token_ids) def _verify_args(self) -> None: if self.n < 1: @@ -226,6 +242,8 @@ def _verify_args(self) -> None: and self.truncate_prompt_tokens < 1): raise ValueError(f"truncate_prompt_tokens must be >= 1, " f"got {self.truncate_prompt_tokens}") + if any(not stop_str for stop_str in self.stop): + raise ValueError("stop cannot contain an empty string.") if self.stop and not self.detokenize: raise ValueError( "stop strings are only supported when detokenize is True. " @@ -261,6 +279,19 @@ def _verify_greedy_sampling(self) -> None: raise ValueError("best_of must be 1 when using greedy sampling." f"Got {self.best_of}.") + def update_from_generation_config( + self, generation_config: Dict[str, Any]) -> None: + """Update if there are non-default values from generation_config""" + # Update eos_token_id for generation + if (not self.ignore_eos) and (eos_ids := + generation_config.get("eos_token_id")): + # it can be either int or list of int + if isinstance(eos_ids, int): + eos_ids = [eos_ids] + original_stop_token_ids = set(self.stop_token_ids) + original_stop_token_ids.update(eos_ids) + self.stop_token_ids = list(original_stop_token_ids) + @cached_property def sampling_type(self) -> SamplingType: if self.use_beam_search: diff --git a/vllm/sequence.py b/vllm/sequence.py index 576bbe8c4f6c4..ac5c234d052bd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,11 +1,14 @@ """Sequence and its related classes.""" import copy import enum -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock +from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams if TYPE_CHECKING: @@ -28,7 +31,10 @@ class Logprob: decoded_token: Optional[str] = None +# {token_id -> logprob} per each sequence group. None if the corresponding +# sequence group doesn't require prompt logprob. PromptLogprobs = List[Optional[Dict[int, Logprob]]] +# {token_id -> logprob} for each sequence group. SampleLogprobs = List[Dict[int, Logprob]] @@ -116,6 +122,7 @@ def __init__( output_token_ids = [] self.prompt_token_ids = prompt_token_ids + self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) self.output_token_ids = output_token_ids self.cumulative_logprob = 0.0 # The number of tokens that are computed (that run against the model). @@ -138,6 +145,17 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.prompt_token_ids + self.output_token_ids + def get_prefix_token_ids( + self, num_tokens: int + ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: + """Get prefix tokens, and make the return value hashable""" + prompt_length = len(self.prompt_token_ids) + if num_tokens > prompt_length: + return (self._prompt_token_ids_tuple, + tuple(self.output_token_ids[:num_tokens - prompt_length])) + else: + return (self._prompt_token_ids_tuple[:num_tokens], None) + def get_num_computed_tokens(self) -> int: """Return the number of prefill tokens that are already computed.""" return self._num_computed_tokens @@ -160,7 +178,7 @@ def reset_state_for_recompute(self) -> None: self._stage = SequenceStage.PREFILL def get_num_uncomputed_tokens(self) -> int: - """Return the number of prefil tokens that are not computed.""" + """Return the number of prefill tokens that are not computed.""" # we use `get_len()` which includes prompt_len + output_len instead # of prompt_len here. This is because during recompute we need to # prefill for both prompt and output. @@ -171,10 +189,10 @@ def get_last_token_id(self) -> int: return self.prompt_token_ids[-1] return self.output_token_ids[-1] - def get_prompt_token_ids(self) -> int: + def get_prompt_token_ids(self) -> List[int]: return self.prompt_token_ids - def get_output_token_ids(self) -> int: + def get_output_token_ids(self) -> List[int]: return self.output_token_ids @property @@ -193,8 +211,7 @@ class Sequence: Args: seq_id: The ID of the sequence. - prompt: The prompt of the sequence. - prompt_token_ids: The token IDs of the prompt. + inputs: The inputs of the sequence. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. lora_request: LoRA request. @@ -203,25 +220,24 @@ class Sequence: def __init__( self, seq_id: int, - prompt: str, - prompt_token_ids: List[int], + inputs: LLMInputs, block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id - self.prompt = prompt + self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request - self.data = SequenceData(prompt_token_ids) + self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the prompt token ids. - self._append_tokens_to_blocks(prompt_token_ids) + self._append_tokens_to_blocks(self.prompt_token_ids) self.status = SequenceStatus.WAITING self.stop_reason: Union[int, str, None] = None @@ -231,10 +247,28 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def prompt(self) -> Optional[str]: + return self.inputs.get("prompt") + + @property + def prompt_token_ids(self) -> List[int]: + return self.inputs["prompt_token_ids"] + + @property + def multi_modal_data(self) -> Optional["MultiModalData"]: + return self.inputs.get("multi_modal_data") + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + def get_output_text_to_return(self, buffer_length: int): + # We return the full output text if the sequence is finished. + truncate = buffer_length and not self.is_finished() + return self.output_text[:-buffer_length] if truncate else ( + self.output_text) + def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size @@ -242,8 +276,8 @@ def hash_of_block(self, logical_idx: int) -> int: # TODO: The current hashing function is O(L^2). We should optimize # this in the future. num_tokens = self.num_hashed_tokens_of_block(logical_idx) - return hash( - (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id)) + hashed_tokens = self.data.get_prefix_token_ids(num_tokens) + return hash((hashed_tokens, self.lora_int_id)) def num_hashed_tokens_of_block(self, logical_idx: int): return logical_idx * self.block_size + self.block_size @@ -339,12 +373,9 @@ def fork(self, new_seq_id: int) -> "Sequence": def get_num_new_tokens(self) -> int: """Get the number of new tokens to be computed. - Args: - remainig_token_budget: The remaining token budgets. Returns: - The new number of tokens to be computed. I.e., 1 for decode, prompt - size for prefill. If there's not enough remainig_token_budget, it - can return the chunked number of new tokens. + The new number of tokens to be computed. I.e., 1 for decode, or + the remaining prompt size for prefill. """ if self.data.stage == SequenceStage.DECODE: return 1 @@ -364,17 +395,17 @@ class SequenceGroupState: """Mutable state tied to a specific sequence group""" # torch.Generator used in seeded sampling - generator: Optional = None + generator: Optional = None # type: ignore class MultiModalData: """Multi modal request. - + Args: type: The data type. data: The actual data. The required shape and semantic meaning of it depends on the vision - language config of the hosted model. + language config of the hosted model. See `VisionLanguageConfig` in `config.py`. """ @@ -395,17 +426,24 @@ class SequenceGroup: sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. lora_request: LoRA request. - multi_modal_data: Multi modal data associated with the request. + embeddings: The embeddings vectors of the prompt of the sequence group + for an embedding model. + pooling_params: The pooling parameters used to generate the pooling + for an embedding model. + encoder_seq: Optional, the single encoder sequence. Should be None + unless you are working with an encoder/decoder model. """ def __init__( self, request_id: str, seqs: List[Sequence], - sampling_params: SamplingParams, arrival_time: float, + sampling_params: Optional[SamplingParams] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, + embeddings: Optional[List[float]] = None, + pooling_params: Optional[PoolingParams] = None, + encoder_seq: Optional[Sequence] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -418,10 +456,12 @@ def __init__( self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() - self.multi_modal_data = multi_modal_data + self.embeddings = embeddings + self.pooling_params = pooling_params + self.encoder_seq = encoder_seq @property - def prompt(self) -> str: + def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).prompt @@ -430,21 +470,39 @@ def prompt(self) -> str: def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return next(iter(self.seqs_dict.values())).data.prompt_token_ids + return next(iter(self.seqs_dict.values())).prompt_token_ids + + @property + def multi_modal_data(self) -> Optional[MultiModalData]: + # All sequences in the group should have the same multi-modal data. + # We use the multi-modal data of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).multi_modal_data @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 - def get_last_latency(self, now: float) -> float: - """Gets last token latency for Request level timings.""" + def get_last_latency(self, now: float) -> Optional[float]: + """Sets the last token time for Request level timings.""" + # If still in prefill phase, raise Error. + if self.is_prefill(): + raise ValueError( + "seq_group.get_last_latency() should not be called " + "if the seq_group is in prefill phase.") + + # Otherwise return token latency. latency = now - self.metrics.last_token_time self.metrics.last_token_time = now return latency def maybe_set_first_token_time(self, time: float) -> None: """Sets the first token time for Request level timings.""" - if self.metrics.first_token_time is None: + # Note: in a case where a sequence_group is swapped and + # recomputed, the time between iterations is counted + # in TPOT, rather than recalculating TTFT (since from the ) + # POV of the user, there is simply a long generation delay. + if (self.metrics.first_token_time is None + and self.get_seqs()[0].get_output_len() == 1): self.metrics.first_token_time = time def maybe_set_first_scheduled_time(self, time: float) -> None: @@ -461,12 +519,13 @@ def set_finished_time(self, time: Optional[float]) -> None: def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining lifetime of the request.""" - if self.sampling_params.use_beam_search: + if self.sampling_params and self.sampling_params.use_beam_search: # For beam search, maximally there will always be `best_of` beam # candidates running in the future. return self.sampling_params.best_of else: - if self.sampling_params.best_of > self.num_seqs(): + if (self.sampling_params + and self.sampling_params.best_of > self.num_seqs()): # At prompt stage, the sequence group is not yet filled up # and only have one sequence running. However, in the # generation stage, we will have `best_of` sequences running. @@ -483,6 +542,12 @@ def get_seqs( seq for seq in self.seqs_dict.values() if seq.status == status ] + def is_encoder_decoder(self) -> bool: + return self.encoder_seq is not None + + def get_encoder_seq(self) -> Optional[Sequence]: + return self.encoder_seq + def get_unfinished_seqs(self) -> List[Sequence]: return [ seq for seq in self.seqs_dict.values() if not seq.is_finished() @@ -500,10 +565,16 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int): def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 for seq in self.get_seqs(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() + if not seq.is_finished(): + num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: + # Optimization. We don't need to call get_seqs if we don't need to + # filter by states. + if status is None: + return len(self.seqs_dict) + return len(self.get_seqs(status)) def num_unfinished_seqs(self) -> int: @@ -531,7 +602,7 @@ def is_finished(self) -> bool: return all(seq.is_finished() for seq in self.get_seqs()) def is_prefill(self) -> bool: - # Every sequences should be in the same stage. + # Every sequence should be in the same stage. return self.get_seqs()[0].is_prefill() def __repr__(self) -> str: @@ -550,11 +621,25 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + do_sample: True if sampling is required. Sampling is not required when + e.g., prefill is chunked, and the current iteration only computes + query tokens for prefill, we don't need sampling. token_chunk_size: The number of tokens to be processed (per sequence). None if chunking is not required. - state: Internal state tied to this sequence group. lora_request: LoRA request. + computed_block_nums: The block numbers that are already computed, + used in prefix caching. + state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. + encoder_seq_data: Optional sequence data for encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. + cross_block_table: Optional cross-attention block table associated + with the encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. """ def __init__( @@ -564,22 +649,36 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + do_sample: bool = True, + pooling_params: Optional[PoolingParams] = None, token_chunk_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, multi_modal_data: Optional[MultiModalData] = None, + encoder_seq_data: Optional[SequenceData] = None, + cross_block_table: Optional[List[int]] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.pooling_params = pooling_params self.lora_request = lora_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state + self.encoder_seq_data = encoder_seq_data + self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size + self.do_sample = do_sample + + # The number of speculative tokens adopted in this request. + # None means specuative decoding is not used. + # Zero means speculative decoding is disabled for some reasons. + # TODO: We should maintain this states out of the sequence group. + self.num_speculative_tokens = None if self._token_chunk_size is None: if is_prompt: @@ -594,6 +693,7 @@ def lora_int_id(self) -> int: @property def token_chunk_size(self) -> int: """Return the number of tokens to be processed (chunk size).""" + assert self._token_chunk_size is not None return self._token_chunk_size @@ -632,8 +732,20 @@ def __eq__(self, other: object) -> bool: return equal and log_probs_equal -class SequenceGroupOutput: - """The model output associated with a sequence group.""" +class SequenceGroupOutput(ABC): + """The base class for model outputs associated with a sequence group.""" + + @abstractmethod + def __repr__(self) -> str: + pass + + @abstractmethod + def __eq__(self, other: object) -> bool: + pass + + +class CompletionSequenceGroupOutput(SequenceGroupOutput): + """The model output associated with a completion sequence group.""" def __init__( self, @@ -641,33 +753,56 @@ def __init__( prompt_logprobs: Optional[PromptLogprobs], ) -> None: self.samples = samples + # Prompt logprob for each prompt query token. self.prompt_logprobs = prompt_logprobs def __repr__(self) -> str: - return (f"SequenceGroupOutput(samples={self.samples}, " + return (f"CompletionSequenceGroupOutput(samples={self.samples}, " f"prompt_logprobs={self.prompt_logprobs})") def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceGroupOutput): + if not isinstance(other, CompletionSequenceGroupOutput): raise NotImplementedError() return (self.samples == other.samples and self.prompt_logprobs == other.prompt_logprobs) +class EmbeddingSequenceGroupOutput(SequenceGroupOutput): + """The model output associated with an embedding sequence group.""" + + def __init__( + self, + embeddings: List[float], + ) -> None: + self.embeddings = embeddings + + def __repr__(self) -> str: + return (f"EmbeddingSequenceGroupOutput(" + f"embeddings_shape={len(self.embeddings)})") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EmbeddingSequenceGroupOutput): + raise NotImplementedError() + return self.embeddings == other.embeddings + + @dataclass class SamplerOutput: """For each sequence group, we generate a list of SequenceOutput object, each of which contains one possible candidate for the next token. - This datastructure implements methods so it can be used like a list, but + This data structure implements methods, so it can be used like a list, but also has optional fields for device tensors. """ - outputs: List[SequenceGroupOutput] + outputs: List[CompletionSequenceGroupOutput] # On-device tensor containing probabilities of each token. sampled_token_probs: Optional["torch.Tensor"] = None + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + # On-device tensor containing the sampled token ids. sampled_token_ids: Optional["torch.Tensor"] = None @@ -686,3 +821,67 @@ def __len__(self): def __eq__(self, other: object): return isinstance(other, self.__class__) and self.outputs == other.outputs + + def __repr__(self) -> str: + """Show the shape of a tensor instead of its values to reduce noise. + """ + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None + else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else + self.sampled_token_ids.shape) + return ( + f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr}, " + f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") + + +@dataclass +class PoolerOutput: + """The output from a pooling operation in the embedding model.""" + outputs: List[EmbeddingSequenceGroupOutput] + + spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + +@dataclass +class ExecuteModelRequest: + """The model execution request.""" + # The sequence group metadata list. + seq_group_metadata_list: List[SequenceGroupMetadata] + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list) + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) + # The number of slots for lookahead decoding. + num_lookahead_slots: int = 0 + # The number of requests in the running queue. + running_queue_size: int = 0 + + def clone( + self, seq_group_metadata_list: List[SequenceGroupMetadata] + ) -> "ExecuteModelRequest": + """Clone the request with a new sequence group metadata list.""" + return ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=self.blocks_to_swap_in.copy(), + blocks_to_swap_out=self.blocks_to_swap_out.copy(), + blocks_to_copy=self.blocks_to_copy.copy(), + num_lookahead_slots=self.num_lookahead_slots, + running_queue_size=self.running_queue_size, + ) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index e0b75837e8a39..7792f3a3425cc 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -1,15 +1,16 @@ from itertools import chain, count -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Iterator, List, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, + SequenceGroupMetadata) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, sampler_output_to_torch, split_batch_by_proposal_len) -from vllm.worker.worker import Worker +from vllm.worker.worker_base import WorkerBase SeqId = int TargetSeqId = int @@ -31,7 +32,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): of topk/tree. """ - def __init__(self, scorer_worker: Worker, device: str, vocab_size: int): + def __init__(self, scorer_worker: WorkerBase, device: str, + vocab_size: int): self._scorer_worker = scorer_worker self._device = device self._vocab_size = vocab_size @@ -39,11 +41,7 @@ def __init__(self, scorer_worker: Worker, device: str, vocab_size: int): @nvtx_range("BatchExpansionTop1Scorer.score_proposals") def score_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - k: int, + execute_model_req: ExecuteModelRequest, proposals: SpeculativeProposals, ) -> SpeculativeScores: """Score the proposed tokens via the scorer model. @@ -56,11 +54,7 @@ def score_proposals( no speculation is produced for that sequence. Args: - seq_group_metadata_list: The input sequence group metadata. - blocks_to_swap_in: This is passed to the worker during scoring. - blocks_to_swap_out: This is passed to the worker during scoring. - blocks_to_copy: This is passed to the worker during scoring. - k: The fixed proposal length. + execute_model_req: The execution request. proposals: The speculative proposals to score. Returns: SpeculativeScores: The scores of each speculative token, along with @@ -71,39 +65,45 @@ def score_proposals( proposal_lens_list = proposals.proposal_lens.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist() + # Filter the list to ignore -1 proposals. + proposal_token_ids_list_without_skips = [ + proposals for proposals in proposal_token_ids_list + if -1 not in proposals + ] + (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) = self._expand_batch( - seq_group_metadata_list=seq_group_metadata_list, - proposal_token_ids_list=proposal_token_ids_list, + seq_group_metadata_list=execute_model_req.seq_group_metadata_list, + proposal_token_ids_list=proposal_token_ids_list_without_skips, proposal_lens_list=proposal_lens_list, ) target_sampler_output = self._scorer_worker.execute_model( - seq_group_metadata_list=target_seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - return_python_output=False) - - all_tokens, all_probs = self._contract_batch( - original_bs=len(seq_group_metadata_list), + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list, )) + assert len(target_sampler_output) == 1, "expected single-step output" + target_sampler_output = target_sampler_output[0] + + all_tokens, all_probs, spec_logprobs = self._contract_batch( + contracted_bs=len(execute_model_req.seq_group_metadata_list), target_sampler_output=target_sampler_output, proposals=proposals, num_scoring_tokens=num_scoring_tokens, non_spec_indices=non_spec_indices, spec_indices=spec_indices, - k=k, + k=execute_model_req.num_lookahead_slots, ) return SpeculativeScores( probs=all_probs, token_ids=all_tokens, + logprobs=spec_logprobs, ) def _expand_batch( self, seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_token_ids_list: List[TokenId], + proposal_token_ids_list: List[List[TokenId]], proposal_lens_list: List[int], ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]: """Given the input sequences and potentially multiple corresponding @@ -125,71 +125,102 @@ def _expand_batch( select_proposal_len_zero=True) target_seq_group_metadata_list = self._create_scoring_model_input( - spec_seqs, proposal_token_ids_list) + seq_group_metadata_list=spec_seqs, + proposal_token_ids=proposal_token_ids_list, + # NOTE: We determine the seq ids in the expanded batch using the + # full seq_group_metadata_list, instead of only spec_seqs. + target_seq_ids_iter=self._create_target_seq_id_iterator( + seq_ids=get_all_seq_ids(seq_group_metadata_list)), + ) + num_scoring_tokens = len(target_seq_group_metadata_list) target_seq_group_metadata_list.extend(non_spec_seqs) return (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) - def _contract_batch(self, original_bs: int, - target_sampler_output: List[SamplerOutput], - proposals: SpeculativeProposals, - num_scoring_tokens: int, non_spec_indices: List[int], - spec_indices: List[int], - k: int) -> Tuple[torch.Tensor, torch.Tensor]: + def _contract_batch( + self, contracted_bs: int, + target_sampler_output: List[SamplerOutput], + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], + k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. + + contracted_bs is the original batch size, and the batch size that the + target_sampler_output will be contracted to. """ - (target_token_ids, target_probs, non_spec_target_token_ids, - non_spec_target_probs) = self._split_scoring_output( + (target_token_ids, target_probs, target_logprobs, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token # of shape [batch_size * k + 1] back to [batch_size, k + 1]. - batch_size, k = proposals.proposal_token_ids.shape + expanded_batch_size, k = proposals.proposal_token_ids.shape + + # The number of tokens in the expanded batch used for speculation is + # equal to the total expanded batch size minus the number of samples for + # non-speculative sequences. + non_spec_expanded_bs, _ = non_spec_target_token_ids.shape + spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs target_token_ids = target_token_ids.squeeze().reshape( - batch_size, k + 1) - target_probs = target_probs.squeeze().reshape(batch_size, k + 1, + spec_expanded_bs, k + 1) + target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1, self._vocab_size) + target_logprobs = target_logprobs.squeeze().reshape( + spec_expanded_bs, k + 1, self._vocab_size) - all_tokens = torch.full(size=(original_bs, k + 1), + all_tokens = torch.full(size=(contracted_bs, k + 1), fill_value=-1, device=self._device, dtype=torch.long) - all_probs = torch.zeros(original_bs, + all_probs = torch.zeros(contracted_bs, k + 1, self._vocab_size, device=self._device, dtype=torch.float32) + all_logprobs = torch.full(size=( + contracted_bs, + k + 1, + self._vocab_size, + ), + fill_value=-float("inf"), + device=self._device, + dtype=torch.float32) if non_spec_indices: - all_tokens[non_spec_indices, 0] = non_spec_target_token_ids + all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs + all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs + all_logprobs[spec_indices] = target_logprobs - return all_tokens, all_probs + return all_tokens, all_probs, all_logprobs def _create_scoring_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + target_seq_ids_iter: Iterator[TargetSeqId], ) -> List[SequenceGroupMetadata]: """Given the original input sequences and proposed tokens from the draft model, create a list of target sequences that can be used for scoring. + + target_seq_ids_iter provides sequence ids for the expanded batch, + fulfilling the requirement that no seq id in the expanded batch is equal + to the seq id in the original batch. """ if not seq_group_metadata_list: return [] - target_seq_ids_iter = self._create_target_seq_id_iterator( - get_all_seq_ids(seq_group_metadata_list)) - target_seq_group_metadata = list( chain.from_iterable( self._create_target_seq_group_metadata( @@ -205,7 +236,7 @@ def _create_scoring_model_input( def _create_target_seq_group_metadata( self, input_seq_group_metadata: SequenceGroupMetadata, - proposal_token_ids: List[TokenId], # shape: [batch_size, k] + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] batch_index: int, target_seq_ids_iter: Iterator[TargetSeqId], ) -> List[SequenceGroupMetadata]: @@ -262,26 +293,36 @@ def _create_single_target_seq_group_metadata( prompt_token_ids = seq_data.get_prompt_token_ids() new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + new_seq_data_dict = { + target_seq_id: + SequenceData( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ), + } + # This is a hack. Technically, spec decoding should compute + # num_lookahead slots at one shot, but instead, it expands the batch + # and evaluate one by one right now. context_len is seq_len - 1 because + # the kv cache is filled by a previous batch in the batch expansion. + for data in new_seq_data_dict.values(): + data.update_num_computed_tokens(data.get_len() - 1) + return SequenceGroupMetadata( request_id=seq_group_metadata.request_id, is_prompt=seq_group_metadata.is_prompt, - seq_data={ - target_seq_id: - SequenceData( - prompt_token_ids=prompt_token_ids, - output_token_ids=new_output_token_ids, - ), - }, + seq_data=new_seq_data_dict, sampling_params=seq_group_metadata.sampling_params, block_tables={ target_seq_id: seq_group_metadata.block_tables[seq_id], }, lora_request=None, + token_chunk_size=1, ) def _split_scoring_output( self, sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]: """Split the target model output into speculative and non-speculative output. """ @@ -301,21 +342,29 @@ def _split_scoring_output( ) = sampler_output.sampled_token_probs.split(split_sizes) (spec_sampled_tokens, non_spec_sampled_tokens ) = sampler_output.sampled_token_ids.flatten().split(split_sizes) + ( + spec_logprobs, + non_spec_logprobs, + ) = sampler_output.logprobs.split(split_sizes) # Convert scores to tensors. sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens - target_token_ids, target_probs = sampler_output_to_torch( - [sampler_output]) + sampler_output.logprobs = spec_logprobs + (target_token_ids, target_probs, + target_logprobs) = sampler_output_to_torch([sampler_output], True) # Convert non-speculative output tokens to tensors. sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens - non_spec_target_token_ids, non_spec_target_probs = ( - sampler_output_to_torch([sampler_output])) + sampler_output.logprobs = non_spec_logprobs + (non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) = sampler_output_to_torch([sampler_output], + True) - return (target_token_ids, target_probs, non_spec_target_token_ids, - non_spec_target_probs) + return (target_token_ids, target_probs, target_logprobs, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: @@ -347,7 +396,7 @@ def _get_token_ids_to_score( [0, 1, 2] [0, 1, 2, 3] """ - empty_token_ids = [] + empty_token_ids: List[TokenId] = [] token_ids_to_score = [empty_token_ids] token_ids_to_score.extend([ diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index 2a72974d01bdc..d311bfe984cbc 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple import torch -from vllm.sequence import SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest @dataclass @@ -24,9 +23,9 @@ class SpeculativeProposals: def __repr__(self): return (f"SpeculativeProposals(" - f"proposal_token_ids={self.proposal_token_ids.shape}, " + f"proposal_token_ids={self.proposal_token_ids}, " f"proposal_probs={self.proposal_probs.shape}, " - f"proposal_lens={self.proposal_lens.shape})") + f"proposal_lens={self.proposal_lens})") @dataclass @@ -38,6 +37,11 @@ class SpeculativeScores: # Probabilities of the speculative tokens according to the scoring model. probs: torch.Tensor + # Log-probabilities of the speculative tokens according to the scoring + # model. These values can be used to generate Logprob objects that are + # returned to the user. + logprobs: torch.Tensor + # Token ids sampled from the scoring model. Used for speculative bonus # tokens and also non-speculative normal decoding. token_ids: torch.Tensor @@ -53,11 +57,7 @@ class SpeculativeProposer(ABC): @abstractmethod def get_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, + execute_model_req: ExecuteModelRequest, ) -> SpeculativeProposals: raise NotImplementedError @@ -67,11 +67,7 @@ class SpeculativeScorer(ABC): @abstractmethod def score_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - k: int, + execute_model_req: ExecuteModelRequest, proposals: SpeculativeProposals, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> SpeculativeScores: raise NotImplementedError diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 5df8fc4316d48..ab1d96c558de7 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -112,6 +112,7 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: Returns a CUDA event recording when the copy is complete. """ + assert self._copy_stream is not None self._copy_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self._copy_stream): @@ -146,15 +147,16 @@ def _collect_rejsample_metrics( emitted_tokens = self._aggregate_num_emitted_tokens.item() draft_tokens = self._aggregate_num_draft_tokens - num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k) + max_num_emitted_tokens = self.get_max_num_emitted_tokens( + draft_tokens, k) if draft_tokens > 0: draft_acceptance_rate = accepted_tokens / draft_tokens else: draft_acceptance_rate = float("nan") - if num_possible_tokens > 0: - system_efficiency = emitted_tokens / num_possible_tokens + if max_num_emitted_tokens > 0: + system_efficiency = emitted_tokens / max_num_emitted_tokens else: system_efficiency = float("nan") @@ -168,8 +170,22 @@ def _collect_rejsample_metrics( ) @staticmethod - def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int: - # Divide by k since batch size can be variable. - total_num_spec_seqs = draft_tokens / k - num_accepted_per_seq_if_all_accepted = k + 1 - return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted) + def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int: + """Calculate the number of emitted tokens, assuming all tokens are + accepted. + + This is equal to the number of sequences that have been speculated on, + times (speculation len + 1). The +1 comes from the bonus token. + """ + # Determine the number of sequences that have been speculated on. Since + # the batch size can be variable, we divide by k. + assert draft_tokens % k == 0 + total_num_spec_seqs = draft_tokens // k + + # A single sequence may emit k accepted tokens and one bonus token in + # the best case. + num_emitted_per_seq_if_all_accepted = k + 1 + + # The max num of emitted tokens is the number of speculated sequences + # times the max emitted per seq. + return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 73b6e201c67a9..b5a805278d273 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,12 +1,13 @@ import copy -from typing import Dict, List, Optional, Tuple +import weakref +from typing import List, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.spec_decode.interfaces import (SpeculativeProposals, - SpeculativeProposer) -from vllm.spec_decode.util import sampler_output_to_torch +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker import Worker @@ -25,76 +26,73 @@ class MultiStepWorker(Worker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._proposer: Optional[DraftModelTop1Proposer] = None + # Lazy initialization list. + self._proposer: Top1Proposer def init_device(self): super().init_device() - self._proposer = DraftModelTop1Proposer( - self, + self._proposer = Top1Proposer( + weakref.proxy(self), self.device, - self.max_model_len, self.vocab_size, + max_proposal_len=self.max_model_len, ) + def set_include_gpu_probs_tensor(self): + # Need include_gpu_probs_tensor for multi_step_worker + self.model_runner.model.sampler.include_gpu_probs_tensor = True + @torch.inference_mode() - def execute_model_multi_step( + def sampler_output( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_steps: int, - ) -> List[SamplerOutput]: - """Run the model forward pass num_steps times. Returns the list of - sampler output, one per model forward pass. + execute_model_req: ExecuteModelRequest, + sample_len: int, + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass sample_len times. Returns the list of + sampler output, one per model forward pass, along with indicator of + whether torch tensor in sampler output need to be transposed in latter + sampler_output_to_torch logic. + + For multi step worker, this indicator shall be True. """ - self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in, - blocks_to_swap_out, blocks_to_copy) + self._raise_if_unsupported(execute_model_req) # Shallow copy input data so modifications (such as appending tokens) # do not cause side-effects. copied_seq_group_metadata_list = self._shallow_copy_inputs( - seq_group_metadata_list) + execute_model_req.seq_group_metadata_list) + copied_execute_model_req = execute_model_req.clone( + copied_seq_group_metadata_list) - # Assert enough KV space for num_steps tokens per sequence. - self._assert_enough_kv_space(seq_group_metadata_list, num_steps) + # Assert enough KV space for sample_len tokens per sequence. + self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list, + sample_len) - # Run model num_steps times. + # Run model sample_len times. model_outputs = [] - for _ in range(num_steps): + for _ in range(sample_len): model_output = super().execute_model( - seq_group_metadata_list=copied_seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + execute_model_req=copied_execute_model_req) + assert (len(model_output) == 1 + ), "composing multistep workers not supported" + model_output = model_output[0] self._append_new_tokens(model_output, copied_seq_group_metadata_list) model_outputs.append(model_output) - return model_outputs + return model_outputs, True def get_spec_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, + execute_model_req: ExecuteModelRequest, ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. """ - return self._proposer.get_proposals( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - max_proposal_len, - ) + return self._proposer.get_proposals(execute_model_req) def _append_new_tokens( self, model_output: SamplerOutput, @@ -116,6 +114,7 @@ def _append_new_tokens( token_logprob = seq_output.logprobs[token_id] seq.append_token_id(token_id, token_logprob.logprob) + seq.update_num_computed_tokens(1) def _shallow_copy_inputs( self, seq_group_metadata_list: List[SequenceGroupMetadata] @@ -185,186 +184,22 @@ def _assert_enough_kv_space( def _raise_if_unsupported( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + execute_model_req: ExecuteModelRequest, ) -> None: """MultiStepWorker does not yet implement support for cache swap operations or beam search. """ - if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): raise NotImplementedError( "MultiStepWorker does not support cache operations") if any( len(seq_group_metadata.seq_data.keys()) != 1 - for seq_group_metadata in seq_group_metadata_list): + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): raise NotImplementedError( "MultiStepWorker does not support beam search.") - - -class DraftModelTop1Proposer(SpeculativeProposer): - """Helper class which separates out sequences which would exceed the max - model length when speculated upon. - - This allows combinations of models such as JackFram/llama-68m draft with - meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of - 2048 while Llama2-13b has max_position_embeddings of 4096. - - We treat the sequences which exceed the proposal draft model length as - "non-spec sequences". Essentially they skip the draft model and go through - normal decoding in the target model. - - Currently, only proposal_lens of 0 and k are supported, where k is a global - batch proposal length. In the future vLLM should support per-sequence - proposal lengths. - """ - - def __init__( - self, - draft_worker: MultiStepWorker, - device: str, - max_model_len: int, - vocab_size: int, - ): - self._draft_worker = draft_worker - self._device = device - self._max_model_len = max_model_len - self._vocab_size = vocab_size - - def get_proposals( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, - ) -> SpeculativeProposals: - """Get speculative proposals given the input batch. - - Sequences which would exceed the max model length are skipped during - speculation. - """ - - # Split speculative- and non-speculative- sequences. - (proposal_lens, nonzero_proposal_len_seqs, - nonzero_proposal_len_indices) = self._split_by_max_model_len( - seq_group_metadata_list, max_proposal_len) - - if nonzero_proposal_len_seqs: - # Speculate tokens using the draft worker for the speculative - # sequences. - maybe_sampler_output = self._draft_worker.execute_model_multi_step( - seq_group_metadata_list=nonzero_proposal_len_seqs, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - num_steps=max_proposal_len, - ) - else: - # If no sequences can be speculated, set sampler output to None. - maybe_sampler_output = None - - # Combine speculative- and non-speculative sequences into the same - # representation. - proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( - batch_size=len(seq_group_metadata_list), - max_proposal_len=max_proposal_len, - maybe_sampler_output=maybe_sampler_output, - proposal_lens=proposal_lens, - nonzero_proposal_len_indices=nonzero_proposal_len_indices, - ) - - proposals = SpeculativeProposals( - proposal_token_ids=proposal_tokens, - proposal_probs=proposal_probs, - proposal_lens=proposal_lens, - ) - - return proposals - - def _split_by_max_model_len( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - max_proposal_len: int, - ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: - """Determine which sequences would exceed the max model length. - """ - - proposal_lens: List[int] = [] - nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] - nonzero_proposal_len_indices: List[int] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_data = next(iter(seq_group_metadata.seq_data.values())) - seq_len = seq_data.get_len() - - # Currently only proposal lens of 0 or the global batch proposal len - # are supported. - if seq_len + max_proposal_len < self._max_model_len: - proposal_lens.append(max_proposal_len) - nonzero_proposal_len_seqs.append(seq_group_metadata) - nonzero_proposal_len_indices.append(i) - else: - proposal_lens.append(0) - - return (proposal_lens, nonzero_proposal_len_seqs, - nonzero_proposal_len_indices) - - def _merge_outputs( - self, - batch_size: int, - max_proposal_len: int, - maybe_sampler_output: Optional[SamplerOutput], - proposal_lens: List[int], - nonzero_proposal_len_indices: List[int], - ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]: - """After speculations are produced, merge the speculation results with - the skipped sequences. - """ - if maybe_sampler_output is None: - # If no speculative tokens, the sampler output will be None. - # In this case we return empty tensors. - proposal_tokens = torch.zeros(0, - max_proposal_len, - dtype=torch.long, - device=self._device) - proposal_probs = torch.zeros(0, - max_proposal_len, - self._vocab_size, - dtype=torch.float32, - device=self._device) - proposal_lens = torch.zeros(len(proposal_lens), - dtype=torch.long, - device=self._device) - return proposal_tokens, proposal_probs, proposal_lens - - sampler_output = maybe_sampler_output - - proposal_tokens, proposal_probs = sampler_output_to_torch( - sampler_output) - - # Now, reformat the output GPU tensors such that each sequence has - # a proposal. the proposal can be empty, e.g. [-1, -1, -1] - - entire_proposal_tokens = torch.full(size=(batch_size, - *proposal_tokens.shape[1:]), - fill_value=-1, - dtype=torch.long, - device=self._device) - entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens - entire_proposal_probs = torch.zeros(batch_size, - *proposal_probs.shape[1:], - dtype=torch.float32, - device=self._device) - entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs - - proposal_tokens, proposal_probs = (entire_proposal_tokens, - entire_proposal_probs) - - proposal_lens = torch.zeros(batch_size, - dtype=torch.long, - device=self._device) - proposal_lens[nonzero_proposal_len_indices] = max_proposal_len - - return proposal_tokens, proposal_probs, proposal_lens diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py new file mode 100644 index 0000000000000..c2b22f2acd7b4 --- /dev/null +++ b/vllm/spec_decode/ngram_worker.py @@ -0,0 +1,186 @@ +import weakref +from typing import List, Optional, Tuple + +import torch + +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + + +class NGramWorker(LoraNotSupportedWorkerBase): + """NGramWorker provides a light drafter without need for model. + + Current NGramWorker only implement prompt lookup decoding, + and in future we may also do RAG type drafter and other scenerios + which don't rely on LLM model to give proposals. + """ + + def __init__(self, *args, **kwargs): + # Get local_rank/vocab_size from kwargs attribute + self.local_rank = kwargs["local_rank"] + self.vocab_size = kwargs["model_config"].get_vocab_size() + + # Lazy initialization list. + self._proposer: Top1Proposer + + def set_ngram_window_size(self, ngram_prompt_lookup_min: int, + ngram_prompt_lookup_max: int): + # Search valid candidate window between + # ngram_prompt_lookup_min/ngram_prompt_lookup_max + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min + + def init_device(self): + self.device = torch.device(f"cuda:{self.local_rank}") + self.load_model = lambda *args, **kwargs: None + + # Current only support Top1Proposer + self._proposer = Top1Proposer( + weakref.proxy(self), + device=self.device, + vocab_size=self.vocab_size, + ) + + def set_include_gpu_probs_tensor(self): + # NGram don't need gpu sampler + pass + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None) -> None: + """NGram doesn't depend on model execution, just pass this function""" + pass + + def determine_num_available_blocks(self) -> None: + """NGram doesn't depend on model execution, no need to check blocks""" + pass + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """As there is no cache need to handle, just pass this function""" + pass + + def get_cache_block_size_bytes(self): + """Return the size of a cache block in bytes.""" + return 0 + + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + ) -> Tuple[Optional[List[SamplerOutput]], bool]: + """NGram match algo to pick proposal candidate. Returns the list of + sampler output, one per SequenceGroupMetadata. + + For ngram worker, we already done needed transposed internal, so the + indicator pass to sampler_output_to_torch shall be False. + """ + self._raise_if_unsupported(execute_model_req) + + has_spec_out = False + token_id_list = [] + token_prob_list = [] + for idx, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): + seq_data = next(iter(seq_group_metadata.seq_data.values())) + + input_ids = torch.as_tensor(seq_data.get_token_ids(), + dtype=torch.long, + device=self.device) + input_length = seq_data.get_len() + + for ngram_size in range( + min(self.ngram_prompt_lookup_max, input_length - 1), + self.ngram_prompt_lookup_min - 1, + -1, + ): + ngram_tensor = input_ids[-ngram_size:] + proposal_start_idx = None + if ngram_size == 1: + # Do not match itself and do not use unfold and all + matches = (input_ids[:-1] == ngram_tensor) + else: + windows = input_ids.unfold(dimension=0, + size=ngram_size, + step=1) + # Do not match itself + matches = (windows[:-1] == ngram_tensor).all(dim=-1) + + # first_match includes "values" (bool), indicating whether + # the match is found, and "indices", indicating the index + # of the first match. + # Note that "first_match.values.item()" triggers GPU-CPU + # sync so it is a bit inefficient, but we have not found + # a better way to do this. + first_match = matches.max(dim=-1) + if first_match.values.item(): + proposal_start_idx = first_match.indices.add_(ngram_size) + spec_indices = ( + proposal_start_idx).repeat(sample_len) + torch.arange( + sample_len, device=self.device) + spec_indices.clamp_(max=input_ids.shape[-1] - 1) + res = input_ids.gather(dim=-1, index=spec_indices) + token_id_list.append(res) + token_prob_list.append( + torch.nn.functional.one_hot( + res, + num_classes=self.vocab_size).to(torch.float32)) + has_spec_out = True + break + else: + token_id_list.append(None) + token_prob_list.append(None) + + if not has_spec_out: + return None, False + + outputs: List[Optional[SamplerOutput]] = [] + for idx in range(len(execute_model_req.seq_group_metadata_list)): + if token_id_list[idx] is None: + outputs.append(None) + else: + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_prob_list[idx], + logprobs=torch.zeros((sample_len, self.vocab_size), + dtype=torch.float32, + device=self.device), + sampled_token_ids=token_id_list[idx], + )) + + return outputs, False + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + + return self._proposer.get_proposals(execute_model_req) + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """NGramWorker does not yet implement support for cache swap + operations or beam search. + """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "NGramWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "NGramWorker does not support beam search.") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 885bf537568e3..150e8db0c8aad 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,20 +1,58 @@ from functools import cached_property -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch +from vllm.distributed.communication_op import broadcast_tensor_dict +from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, - SequenceGroupOutput, SequenceOutput) +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.multi_step_worker import MultiStepWorker -from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, +from vllm.spec_decode.ngram_worker import NGramWorker +from vllm.spec_decode.util import (create_sequence_group_output, + get_all_num_logprobs, get_all_seq_ids, + get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker import Worker -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase + +logger = init_logger(__name__) + + +def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": + """Helper method that is the entrypoint for Executors which use + WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. + """ + assert "speculative_config" in kwargs + speculative_config = kwargs.get("speculative_config") + assert speculative_config is not None + + target_worker = Worker(*args, **kwargs) + + draft_worker_kwargs = kwargs.copy() + # Override draft-model specific worker args. + draft_worker_kwargs.update( + model_config=speculative_config.draft_model_config, + parallel_config=speculative_config.draft_parallel_config, + ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, + # TODO allow draft-model specific load config. + #load_config=load_config, + ) + + spec_decode_worker = SpecDecodeWorker.create_worker( + scorer_worker=target_worker, + draft_worker_kwargs=draft_worker_kwargs, + disable_by_batch_size=speculative_config. + speculative_disable_by_batch_size, + ) + + return spec_decode_worker class SpecDecodeWorker(LoraNotSupportedWorkerBase): @@ -45,12 +83,45 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. """ + @classmethod + def create_worker( + cls, + scorer_worker: WorkerBase, + draft_worker_kwargs: Dict[str, Any], + disable_by_batch_size: Optional[int], + ) -> "SpecDecodeWorker": + + ngram_prompt_lookup_max = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_max")) + ngram_prompt_lookup_min = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + + disable_bonus_tokens = True + if ngram_prompt_lookup_max > 0: + disable_bonus_tokens = False + proposer_worker = NGramWorker(**draft_worker_kwargs) + proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, + ngram_prompt_lookup_max) + else: + proposer_worker = MultiStepWorker(**draft_worker_kwargs) + + logger.info("Configuring SpecDecodeWorker with proposer=%s", + type(proposer_worker)) + + return SpecDecodeWorker( + proposer_worker, + scorer_worker, + disable_by_batch_size=disable_by_batch_size, + rejection_sampler=RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens, )) + def __init__( self, - proposer_worker: MultiStepWorker, - scorer_worker: Worker, + proposer_worker: WorkerBase, + scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, + disable_by_batch_size: Optional[int] = None, ): """ Create a SpecDecodeWorker. @@ -63,11 +134,14 @@ def __init__( Worker. rejection_sampler: A Torch module used to perform modified rejection sampling for speculative decoding. + disable_by_batch_size: If the batch size is larger than this, + disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set for testing purposes. """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker + self.disable_by_batch_size = disable_by_batch_size or float("inf") self.rejection_sampler = rejection_sampler self._metrics = AsyncMetricsCollector( @@ -77,7 +151,8 @@ def __init__( self.probs_dtype = self.rejection_sampler.probs_dtype self.token_id_dtype = self.rejection_sampler.token_id_dtype - self.scorer: SpeculativeScorer = None + # Lazy initiazliation. + self.scorer: SpeculativeScorer def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -87,6 +162,10 @@ def init_device(self) -> None: self.scorer_worker.init_device() self.proposer_worker.init_device() + # NOTE(cade): load_model is not part of the WorkerBase interface. + self.scorer_worker.load_model() + self.proposer_worker.load_model() + self._metrics.init_gpu_tensors(self.rank) self.rejection_sampler.init_gpu_tensors(self.rank) self.scorer = BatchExpansionTop1Scorer( @@ -94,6 +173,34 @@ def init_device(self) -> None: device=self.device, vocab_size=self._vocab_size) + self._configure_model_sampler_for_spec_decode() + + def load_model(self, *args, **kwargs): + pass + + def _configure_model_sampler_for_spec_decode(self): + """Configure model sampler to emit GPU tensors. This allows spec decode + to keep data on device without transferring to CPU and serializing, + which significantly reduces overhead of rejection sampling. + + NOTE(cade): This breaks abstraction boundaries pretty badly. The better + design is to have the "move to CPU and serialize" sampling decision be + done outside of the model/sampler; this way the "last-mile" worker + object which interfaces with the scheduler can serialize and incur the + performance hit as necessary. This allows us to run the worker several + iterations in a row without incurring the "move to CPU and serialize" + performance penalty. + + Since this requires a large change to vLLM, we defer it to later and + temporarily accept this broken abstraction boundary. + + NOTE(cade): This will require a special check if the proposer worker + does not have a sampler (e.g. ngram speculation). + """ + (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor + ) = True + self.proposer_worker.set_include_gpu_probs_tensor() + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use. @@ -127,79 +234,144 @@ def initialize_cache(self, num_gpu_blocks: int, @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - num_spec_tokens: int, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ - - assert seq_group_metadata_list is not None, ( - "speculative decoding " - "requires non-None seq_group_metadata_list") - - # If no spec tokens, call the proposer and scorer workers normally. - # Used for prefill. - if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0: - return self._run_no_spec( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) - - return self._run_speculative_decoding_step( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - k=num_spec_tokens, + if self.rank != self._driver_rank: + self._run_non_driver_rank() + return [] + + if execute_model_req is None: + # This signals that there's no more requests to process for now. + # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. + broadcast_tensor_dict({}, src=0) + return [] + + disable_all_speculation = self._should_disable_all_speculation( + execute_model_req) + num_lookahead_slots = execute_model_req.num_lookahead_slots + + # Broadcast how many lookahead slots are scheduled for this step, and + # whether all speculation is disabled, to all non-driver workers. + + # This is required as if the number of draft model runs changes + # dynamically, the non-driver workers won't know unless we perform a + # communication to inform then. + broadcast_dict = dict( + num_lookahead_slots=num_lookahead_slots, + disable_all_speculation=disable_all_speculation, ) + broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) + + assert execute_model_req.seq_group_metadata_list is not None, ( + "speculative decoding requires non-None seq_group_metadata_list") + + self._maybe_disable_speculative_tokens( + disable_all_speculation, execute_model_req.seq_group_metadata_list) + + # Speculative decoding is disabled in the following cases: + # 1. Prefill phase: Speculative decoding is not + # used during the prefill phase. + # 2. Auto-disable enabled: The running queue size exceeds + # the specified threshold. + # 3. No request: There are no requests in the batch. + # In any of these cases, the proposer and scorer workers + # are called normally. + if num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list + ) == 0 or disable_all_speculation: + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all_speculation) + + return self._run_speculative_decoding_step(execute_model_req, + num_lookahead_slots) + + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop to perform speculative decoding + in parallel worker.""" + while self._run_non_driver_rank(): + pass + + def _should_disable_all_speculation( + self, execute_model_req: ExecuteModelRequest) -> bool: + # When the batch size is too large, disable speculative decoding + # to stop trading off throughput for latency. + disable_all_speculation = (execute_model_req.running_queue_size >= + self.disable_by_batch_size) + + return disable_all_speculation + + def _maybe_disable_speculative_tokens( + self, disable_all_speculation: bool, + seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: + if not disable_all_speculation: + return + + for seq_group_metadata in seq_group_metadata_list: + # Once num_speculative_tokens is set to 0, the spec decode + # of this request will be disabled forever. + # TODO(comaniac): We currently store spec decoding specific + # state in the global data structure, but we should maintain + # this state within spec decode worker. + seq_group_metadata.num_speculative_tokens = 0 @nvtx_range("spec_decode_worker._run_no_spec") - def _run_no_spec( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - ) -> List[SamplerOutput]: - """Run a prefill step, without any speculation. The input is sent to the - proposer and scorer model so that the KV cache is consistent between the - two. + def _run_no_spec(self, execute_model_req: ExecuteModelRequest, + skip_proposer: bool) -> List[SamplerOutput]: + """Run a single generation step without any speculation. The input is + sent to the proposer and scorer model so that the KV cache is consistent + between the two. When skip_proposer is True, the proposer model is + not called, meaning that the kv-cache in proposer for requests is not + updated, so they cannot enable spec decode in the rest decoding. """ + if not skip_proposer: + self.proposer_worker.execute_model(execute_model_req) - self.proposer_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - return_python_output=False) - - sampler_output = self.scorer_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + sampler_output = self.scorer_worker.execute_model(execute_model_req) + assert len(sampler_output) == 1 + sampler_output = sampler_output[0] # Clear device tensors from sampler output. This reduces communication # overhead when the engine runs in a different process than the workers. sampler_output.probs = None sampler_output.sampled_tokens = None + sampler_output.logprobs = None return [sampler_output] + def _run_non_driver_rank(self) -> bool: + """Run proposer and verifier model in non-driver workers. This is used + for both speculation cases (num_lookahead_slots>0) and non-speculation + cases (e.g. prefill). + + Returns True iff there are remaining sequences to process. + """ + assert self.rank != self._driver_rank + + data = broadcast_tensor_dict(src=self._driver_rank) + if not data: + return False + num_lookahead_slots = data["num_lookahead_slots"] + + # Even if num_lookahead_slots is zero, we want to run the proposer model + # as it may have KV. + # + # We run the proposer once per lookahead slot. In the future we should + # delegate how many times it runs to the proposer. + for _ in range(max(num_lookahead_slots, 1)): + self.proposer_worker.execute_model() + + self.scorer_worker.execute_model() + return True + @nvtx_range("spec_decode_worker._run_speculative_decoding_step") def _run_speculative_decoding_step( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - k: int, - ) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest, + num_lookahead_slots: int) -> List[SamplerOutput]: """Execute a single step of speculative decoding. This invokes the proposer worker to get k speculative tokens for each @@ -208,26 +380,25 @@ def _run_speculative_decoding_step( Returns a list of SamplerOutput, each containing a single token per sequence. """ + assert num_lookahead_slots == execute_model_req.num_lookahead_slots # Generate proposals using draft worker. - proposals = self.proposer_worker.get_spec_proposals( - seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, - blocks_to_copy, k) + proposals = self.proposer_worker.get_spec_proposals(execute_model_req) proposal_scores = self.scorer.score_proposals( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - k, + execute_model_req, proposals, ) - accepted_token_ids = self._verify_tokens(seq_group_metadata_list, - proposal_scores, proposals, k) + accepted_token_ids, target_logprobs = self._verify_tokens( + execute_model_req.seq_group_metadata_list, proposal_scores, + proposals, execute_model_req.num_lookahead_slots) - return self._create_output_sampler_list(seq_group_metadata_list, - accepted_token_ids, k) + return self._create_output_sampler_list( + execute_model_req.seq_group_metadata_list, + accepted_token_ids, + target_logprobs=target_logprobs, + k=execute_model_req.num_lookahead_slots) @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( @@ -236,9 +407,12 @@ def _verify_tokens( proposal_scores: SpeculativeScores, proposals: SpeculativeProposals, max_proposal_len: int, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """Determine which speculative tokens are accepted using the probabilities of each token according to the proposer and scorer models. + + Returns a tuple of Tensors, one for the accepted token ids and one for + the logprobs according to the scoring model. """ proposal_lens_list = proposals.proposal_lens.tolist() @@ -256,15 +430,26 @@ def _verify_tokens( select_proposal_len_zero=True) original_indices = spec_indices + non_spec_indices - proposal_probs = proposal_scores.probs[spec_indices, :-1] - bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] + # Get probabilities of target model, excluding bonus token. + proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1] + + # Get non-speculative sampled tokens from target model. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] + # Get bonus tokens from target model. + bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] + + # Get probabilities according to proposal method. + proposal_probs = proposals.proposal_probs[spec_indices] + + # Get proposed tokens. + proposal_token_ids = proposals.proposal_token_ids[spec_indices] + accepted_token_ids = self.rejection_sampler( - proposal_probs, - bonus_token_ids, - proposals.proposal_probs, - proposals.proposal_token_ids, + target_probs=proposal_verifier_probs, + bonus_token_ids=bonus_token_ids, + draft_probs=proposal_probs, + draft_token_ids=proposal_token_ids, ) # Append output tokens from non-speculative sequences to @@ -274,17 +459,19 @@ def _verify_tokens( non_spec_token_ids[:, 1:] = -1 accepted_token_ids = torch.cat( [accepted_token_ids, non_spec_token_ids]) + logprobs = proposal_scores.logprobs # Rearrange so that results are in the order of the original seq group # metadata. accepted_token_ids[original_indices] = accepted_token_ids.clone() - return accepted_token_ids + return accepted_token_ids, logprobs def _create_output_sampler_list( self, seq_group_metadata_list: List[SequenceGroupMetadata], accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] + target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] k: int, ) -> List[SamplerOutput]: """Given the accepted token ids, create a list of SamplerOutput. @@ -292,30 +479,68 @@ def _create_output_sampler_list( The output is padded with -1 tokens such that each sequence has the same number of outputs. """ + batch_size, num_steps = accepted_token_ids.shape + + # Organize input tensors by step instead of by sequence. + target_logprobs_by_step = target_logprobs.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1) + + # Get the logprobs/rank of the accepted tokens. + (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs( + logprob_tensor=target_logprobs_by_step, + sampled_token_ids=accepted_token_ids_by_step, + ) + + # Get the top-k logprobs (which may or may not include the logprob of + # the accepted token). + (topk_logprobs_by_step, + topk_indices_by_step) = target_logprobs_by_step.topk( + k=self.scorer_worker.model_config.max_logprobs, + dim=-1, + ) + + # Get the sequence ids and num_logprobs (sampling parameter) in the + # batch. seq_ids = get_all_seq_ids(seq_group_metadata_list) - - # shape: [k+1, batch_size] - accepted_token_ids_by_step = accepted_token_ids.transpose(0, - 1).tolist() + num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) + + # Serialize all tensors to CPU Python lists. + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + accepted_token_id_ranks_by_step = ( + accepted_token_id_ranks_by_step.tolist()) + accepted_token_id_logprobs_by_step = ( + accepted_token_id_logprobs_by_step.tolist()) + topk_logprobs_by_step = topk_logprobs_by_step.tolist() + topk_indices_by_step = topk_indices_by_step.tolist() + + # Construct the output on a per-step, per-sequence basis. sampler_output_list = [] - for token_ids_by_step in accepted_token_ids_by_step: - if all(token_id == -1 for token_id in token_ids_by_step): + for step_index in range(num_steps): + if all(token_id == -1 + for token_id in accepted_token_ids_by_step[step_index]): break step_output_token_ids = [] - for token_id, seq_id in zip(token_ids_by_step, seq_ids): + for sequence_index in range(batch_size): + # Each sequence may have a different num_logprobs; retrieve it. + num_logprobs = num_logprobs_per_seq[sequence_index] + step_output_token_ids.append( - SequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq_id, - output_token=token_id, - # TODO Add verifier logprobs. - logprobs={token_id: 0.0}, - ) - ], - prompt_logprobs=None, + create_sequence_group_output( + token_id=accepted_token_ids_by_step[step_index] + [sequence_index], + token_id_logprob_rank=accepted_token_id_ranks_by_step[ + step_index][sequence_index], + token_id_logprob=accepted_token_id_logprobs_by_step[ + step_index][sequence_index], + seq_id=seq_ids[sequence_index], + topk_token_ids=topk_indices_by_step[step_index] + [sequence_index][:num_logprobs], + topk_logprobs=topk_logprobs_by_step[step_index] + [sequence_index][:num_logprobs], )) + sampler_output_list.append( SamplerOutput(outputs=step_output_token_ids)) @@ -347,6 +572,10 @@ def rank(self): def device(self): return self.scorer_worker.device + @property + def _driver_rank(self) -> int: + return 0 + def get_cache_block_size_bytes(self): """Return the size of a cache block in bytes. diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py new file mode 100644 index 0000000000000..6c7e22207f6b2 --- /dev/null +++ b/vllm/spec_decode/top1_proposer.py @@ -0,0 +1,274 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeProposer) +from vllm.spec_decode.util import sampler_output_to_torch +from vllm.worker.worker_base import WorkerBase + + +class Top1Proposer(SpeculativeProposer): + """Helper class which separates out sequences which would exceed the max + model length when speculated upon. + + This allows combinations of models such as JackFram/llama-68m draft with + meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of + 2048 while Llama2-13b has max_position_embeddings of 4096. + + We treat the sequences which exceed the proposal draft model length as + "non-spec sequences". Essentially they skip the draft model and go through + normal decoding in the target model. + + Currently, only proposal_lens of 0 and k are supported, where k is a global + batch proposal length. In the future vLLM should support per-sequence + proposal lengths. + """ + + def __init__( + self, + worker: WorkerBase, + device: str, + vocab_size: int, + max_proposal_len: Optional[int] = None, + ): + self._worker = worker + self._device = device + self.max_proposal_len = max_proposal_len + self._vocab_size = vocab_size + + def get_proposals( + self, + execute_model_req: ExecuteModelRequest, + ) -> SpeculativeProposals: + """Get speculative proposals given the input batch. + + Sequences which would exceed the max model length are skipped during + speculation. + """ + proposal_len = execute_model_req.num_lookahead_slots + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + # Split speculative- and non-speculative- sequences. + ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len) + + if nonzero_proposal_len_seqs: + # Speculate tokens using the draft worker for the speculative + # sequences. + # If sampler_transposed is true, then maybe_sampler_output's + # token_ids is like [batch] format in proposal_len size list, + # while if it is false, the format would be [proposal_len] + # in batch size list + nonzero_execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=nonzero_proposal_len_seqs, + num_lookahead_slots=proposal_len, + ) + maybe_sampler_output, transposed = self._worker.sampler_output( + execute_model_req=nonzero_execute_model_req, + sample_len=proposal_len, + ) + ( + proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + ) = self._remove_no_proposal_seqs(proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + transposed) + else: + # If no sequences can be speculated, set sampler output to None. + maybe_sampler_output = None + transposed = False + + # Combine speculative- and non-speculative sequences into the same + # representation. + proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( + batch_size=len(seq_group_metadata_list), + proposal_len=proposal_len, + maybe_sampler_output=maybe_sampler_output, + proposal_lens=proposal_lens, + nonzero_proposal_len_indices=nonzero_proposal_len_indices, + sampler_transposed=transposed, + ) + + proposals = SpeculativeProposals( + proposal_token_ids=proposal_tokens, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens, + ) + + return proposals + + def _split_by_proposal_len( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_len: int, + ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: + """Split sequences by two groups: + 1. Sequences with non-zero proposal length. + 2. Sequences with zero proposal length (due to disabled speculation + or exceed the maximum model length). + """ + + proposal_lens: List[int] = [] + nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] + nonzero_proposal_len_indices: List[int] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + # The speculative decoding for this request has been disabled + # (e.g. due to high traffic). + if seq_group_metadata.num_speculative_tokens == 0: + proposal_lens.append(0) + continue + + seq_data = next(iter(seq_group_metadata.seq_data.values())) + seq_len = seq_data.get_len() + + # Currently only proposal lens of 0 or the global batch proposal len + # are supported. + # If max_proposal_len is defined, then we shall no exccess this + # quota for nonzero_proposal + new_k = 0 + if (self.max_proposal_len is None + or seq_len + proposal_len < self.max_proposal_len): + new_k = proposal_len + nonzero_proposal_len_seqs.append(seq_group_metadata) + nonzero_proposal_len_indices.append(i) + proposal_lens.append(new_k) + seq_group_metadata.num_speculative_tokens = new_k + + return ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) + + def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output, + nonzero_proposal_len_indices, transposed): + """Remove sequences from nonzero_proposal_len_indices and reset + their proposal_len to 0 the draft worker does not provide a proposal + (maybe_sampler_output=None). This can avoid scoring overheads. + """ + + # If maybe_sampler_output is None, then the draft worker did not + # provide a proposal for any sequence and thus no action needed. + # Also we do not support transposed maybe_sampler_output for now + # because it seems not straightforward for draft workers outputting + # transposed sampler outputs to handle the case of no proposal. + if maybe_sampler_output is None or transposed: + return (proposal_lens, maybe_sampler_output, + nonzero_proposal_len_indices) + + new_proposal_lens: List[int] = [] + new_nonzero_proposal_len_indices: List[int] = [] + new_maybe_sampler_output: List[SamplerOutput] = [] + nonzero_proposal_len_idx_ptr = 0 + seq_idx = 0 + while seq_idx < len( + proposal_lens) and nonzero_proposal_len_idx_ptr < len( + nonzero_proposal_len_indices): + if seq_idx < nonzero_proposal_len_indices[ + nonzero_proposal_len_idx_ptr]: + # Sequence is not in the original nonzero_proposal_len_indices, + # meaning that it has a proposal length of 0 before sending to + # the draft worker. + assert proposal_lens[seq_idx] == 0 + new_proposal_lens.append(0) + else: + # Sequence is in the original nonzero_proposal_len_indices + if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None: + # but does not have a proposal from the draft worker. + new_proposal_lens.append(0) + else: + # and has a proposal from the draft worker. Add it to the + # new nonzero proposal list and keep the sampler output. + new_proposal_lens.append(proposal_lens[seq_idx]) + new_nonzero_proposal_len_indices.append(seq_idx) + new_maybe_sampler_output.append( + maybe_sampler_output[nonzero_proposal_len_idx_ptr]) + nonzero_proposal_len_idx_ptr += 1 + seq_idx += 1 + + # The remaining sequences should have proposal length of 0. + new_proposal_lens.extend(proposal_lens[seq_idx:]) + + # We assume sampler_output will not be a list of all Nones. + # In this case this function should not be called. + assert new_maybe_sampler_output + return (new_proposal_lens, new_maybe_sampler_output, + new_nonzero_proposal_len_indices) + + def _merge_outputs( + self, + batch_size: int, + proposal_len: int, + maybe_sampler_output: Optional[SamplerOutput], + proposal_lens: List[int], + nonzero_proposal_len_indices: List[int], + sampler_transposed: bool, + ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]: + """After speculations are produced, merge the speculation results with + the skipped sequences. + """ + if maybe_sampler_output is None: + # If no speculative tokens, the sampler output will be None. + # In this case we return empty proposals. + proposal_tokens = torch.full( + size=( + batch_size, + proposal_len, + ), + fill_value=-1, + dtype=torch.long, + device=self._device, + ) + proposal_probs = torch.zeros( + batch_size, + proposal_len, + self._vocab_size, + dtype=torch.float32, + device=self._device, + ) + proposal_lens_tensor = torch.zeros(len(proposal_lens), + dtype=torch.long, + device=self._device) + return proposal_tokens, proposal_probs, proposal_lens_tensor + + sampler_output = maybe_sampler_output + proposal_tokens, proposal_probs, _ = sampler_output_to_torch( + sampler_output, sampler_transposed) + + # Now, reformat the output GPU tensors such that each sequence has + # a proposal. the proposal can be empty, e.g. [-1, -1, -1] + + entire_proposal_tokens = torch.full( + size=(batch_size, *proposal_tokens.shape[1:]), + fill_value=-1, + dtype=torch.long, + device=self._device, + ) + entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens + entire_proposal_probs = torch.zeros( + batch_size, + *proposal_probs.shape[1:], + dtype=torch.float32, + device=self._device, + ) + entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs + + proposal_tokens, proposal_probs = ( + entire_proposal_tokens, + entire_proposal_probs, + ) + + proposal_lens_tensor = torch.zeros(batch_size, + dtype=torch.long, + device=self._device) + proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len + + return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 406568a4bc08c..4dc6c49eb58d2 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -1,10 +1,12 @@ from contextlib import contextmanager from itertools import chain -from typing import List, Tuple +from typing import Dict, List, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SamplerOutput, SequenceGroupMetadata, + SequenceGroupOutput, SequenceOutput) SeqId = int @@ -21,6 +23,89 @@ def get_all_seq_ids( ])) +def get_all_num_logprobs( + seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: + """Given a list of SequenceGroupMetadata, create a list of all num_logprobs. + + If the sampling params do not call for any logprobs, return 0 for that + sequence. + """ + + all_num_logprobs = [] + for seq_group_metadata in seq_group_metadata_list: + num_logprobs = seq_group_metadata.sampling_params.logprobs + if seq_group_metadata.sampling_params.logprobs is None: + num_logprobs = 0 + all_num_logprobs.append(num_logprobs) + + return all_num_logprobs + + +def get_sampled_token_logprobs( + # shape [num_steps, batch_size, vocab_size] + logprob_tensor: torch.Tensor, + sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size] +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the logprobs for the sampled tokens. Returns the ranks and logprobs. + """ + num_steps, batch_size, vocab_size = logprob_tensor.shape + + selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1), + torch.arange(batch_size), + sampled_token_ids, ] + expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand( + -1, -1, vocab_size) + sampled_token_ids_ranks = (logprob_tensor >= + expanded_selected_logprobs).sum(-1) + + return sampled_token_ids_ranks, selected_logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[int], + topk_logprobs: List[float], +) -> SequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[int]): The list of top-k token ids. + topk_logprobs (List[float]): The list of top-k logprobs. + """ + # vLLM logprobs always include the sampled token. In addition, the user may + # request topk-logprobs (where top-k varies per user up to max_logprobs). + logprobs: Dict[int, Logprob] = { + token_id: Logprob( + logprob=token_id_logprob, + rank=token_id_logprob_rank, + ), + } + logprobs.update({ + topk_token_ids[topk_logprob_index]: Logprob( + logprob=topk_logprobs[topk_logprob_index], + rank=topk_logprob_index + 1, + ) + for topk_logprob_index, _ in enumerate(topk_token_ids) + }) + + return CompletionSequenceGroupOutput( + samples=[ + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs=logprobs) + ], + # TODO add prompt logprobs support. + prompt_logprobs=None, + ) + + def split_batch_by_proposal_len( seq_group_metadata_list: List[SequenceGroupMetadata], proposal_lens: List[int], select_proposal_len_zero: bool @@ -49,10 +134,13 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( - sampler_output_list: List[SamplerOutput], -) -> Tuple[torch.Tensor, torch.Tensor]: + sampler_output_list: List[SamplerOutput], sampler_transposed: bool +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Utility function which converts a list of SamplerOutput to tensors. + sampler_transposed here is used as the indicator for whether + we need do additional tensor transpose logic here. + Returns: sampled_token_ids: torch.Tensor shape: [batch_size, len(sampler_output_list)] @@ -68,7 +156,19 @@ def sampler_output_to_torch( for sampler_output in sampler_output_list ], dim=0, - ).transpose(0, 1) + ) + + if sampler_transposed: + sampled_token_probs = sampled_token_probs.transpose(0, 1) + + # shape: [batch_size, num_sampler_output, vocab_size] + sampled_token_logprobs = torch.stack( + [sampler_output.logprobs for sampler_output in sampler_output_list], + dim=0, + ) + + if sampler_transposed: + sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1) # shape: [batch_size, num_sampler_output] sampled_token_ids = torch.stack( @@ -77,9 +177,37 @@ def sampler_output_to_torch( for sampler_output in sampler_output_list ], dim=0, - ).transpose(0, 1) + ) + if sampler_transposed: + sampled_token_ids = sampled_token_ids.transpose(0, 1) + + return sampled_token_ids, sampled_token_probs, sampled_token_logprobs + - return sampled_token_ids, sampled_token_probs +def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, + vocab_size: int, device: str) -> None: + """Helper method which mocks out the GPU tensors in SamplerOutput with dummy + values. This will be removed in PR 7/9. + https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer + """ + values = [ + sampler_output.sampled_token_probs, sampler_output.sampled_token_ids + ] + assert all(v is None for v in values) or not any(v is None for v in values) + if not any(v is None for v in values): + # Do nothing if the tensors are already created (usually in unit tests). + return + + # Softmax to ensure valid probs. + sampler_output.sampled_token_probs = torch.nn.functional.softmax( + torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), + dim=-1) + + sampler_output.sampled_token_ids = torch.randint(low=10, + high=100, + size=(batch_size, ), + dtype=torch.long, + device=device) @contextmanager diff --git a/vllm/test_utils.py b/vllm/test_utils.py deleted file mode 100644 index bc220d3b8a430..0000000000000 --- a/vllm/test_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -import ray - -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) -from vllm.utils import get_open_port - - -def init_test_distributed_environment( - pipeline_parallel_size: int, - tensor_parallel_size: int, - rank: int, - distributed_init_port: str, - local_rank: int = -1, -) -> None: - distributed_init_method = f"tcp://localhost:{distributed_init_port}" - init_distributed_environment( - world_size=pipeline_parallel_size * tensor_parallel_size, - rank=rank, - distributed_init_method=distributed_init_method, - local_rank=local_rank) - ensure_model_parallel_initialized(tensor_parallel_size, - pipeline_parallel_size) - - -def multi_process_tensor_parallel( - tensor_parallel_size: int, - test_target, -) -> None: - # Using ray helps debugging the error when it failed - # as compared to multiprocessing. - ray.init() - - distributed_init_port = get_open_port() - refs = [] - for rank in range(tensor_parallel_size): - refs.append( - test_target.remote(tensor_parallel_size, rank, - distributed_init_port)) - ray.get(refs) - - ray.shutdown() diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8a6ba6c5b396c..044eec6410a54 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,10 +1,14 @@ -from typing import Optional +from typing import Dict, Optional from transformers import AutoConfig, PretrainedConfig -from vllm.transformers_utils.configs import * +from vllm.logger import init_logger +from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, + JAISConfig, MPTConfig, RWConfig) -_CONFIG_REGISTRY = { +logger = init_logger(__name__) + +_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, "mpt": MPTConfig, @@ -17,7 +21,8 @@ def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None, - code_revision: Optional[str] = None) -> PretrainedConfig: + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None) -> PretrainedConfig: try: config = AutoConfig.from_pretrained( model, @@ -40,6 +45,10 @@ def get_config(model: str, config = config_class.from_pretrained(model, revision=revision, code_revision=code_revision) + if rope_scaling is not None: + logger.info("Updating rope_scaling from %r to %r", + getattr(config, "rope_scaling", None), rope_scaling) + config.update({"rope_scaling": rope_scaling}) return config @@ -54,4 +63,4 @@ def get_hf_text_config(config: PretrainedConfig): assert hasattr(config.text_config, "num_attention_heads") return config.text_config else: - return config + return config \ No newline at end of file diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py new file mode 100644 index 0000000000000..7780bf5e78d6d --- /dev/null +++ b/vllm/transformers_utils/configs/arctic.py @@ -0,0 +1,204 @@ +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://huggingface.co/Snowflake/snowflake-arctic-instruct/blob/main/configuration_arctic.py +""" Arctic model configuration""" + +from dataclasses import asdict, dataclass +from typing import Any, Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "arctic": "https://huggingface.co/Snowflake/snowflake-arctic-instruct/tree/main/config.json", +} + + +@dataclass +class ArcticLoraConfig: + lora_r: int = 64 + lora_alpha: float = 16 + shard_base_weights: bool = False + + +@dataclass +class ArcticQuantizationConfig: + q_bits: int = 8 + rounding: str = "nearest" + mantissa_bits: int = 3 + group_size: int = 128 + + +class ArcticConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ArcticModel`]. It is used to instantiate an + Arctic model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the #TODO(rsamdani): add what model has the default config.. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ArcticModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Arctic's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import ArcticModel, ArcticConfig + + >>> # Initializing a Arctic 7B style configuration TODO(rsamdani): verify which model does the default configuration correspond to. + >>> configuration = ArcticConfig() + + >>> # Initializing a model from the Arctic 7B style configuration + >>> model = ArcticModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "arctic" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=1, + num_local_experts=8, + router_aux_loss_coef=0.001, + moe_layer_frequency=2, + parallel_attn_mlp_res=False, + moe_train_capacity_factor=1, + moe_eval_capacity_factor=1, + enable_expert_tensor_parallelism=False, + moe_min_capacity=0, + moe_token_dropping=True, + quantization=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_aux_loss_coef = router_aux_loss_coef + self.moe_layer_frequency = moe_layer_frequency + self.moe_train_capacity_factor = moe_train_capacity_factor + self.moe_eval_capacity_factor = moe_eval_capacity_factor + self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism + self.moe_min_capacity = moe_min_capacity + self.moe_token_dropping = moe_token_dropping + self.parallel_attn_mlp_res = parallel_attn_mlp_res + if isinstance(quantization, dict): + self.quantization = ArcticQuantizationConfig(**quantization) + else: + self.quantization = quantization + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "ArcticConfig": + result = super().from_dict(config_dict, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if isinstance(config.quantization, dict): + config.quantization = ArcticQuantizationConfig(**config.quantization) + return result + + def to_dict(self) -> Dict[str, Any]: + ret = super().to_dict() + if isinstance(ret["quantization"], ArcticQuantizationConfig): + ret["quantization"] = asdict(ret["quantization"]) + return ret diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py index c4244f8c77f44..49d2b8d8e21b1 100644 --- a/vllm/transformers_utils/configs/chatglm.py +++ b/vllm/transformers_utils/configs/chatglm.py @@ -46,6 +46,8 @@ def __init__(self, self.kv_channels = kv_channels self.num_attention_heads = num_attention_heads self.seq_length = seq_length + # It is to be compatible with long lora. + self.max_position_embeddings = seq_length self.hidden_dropout = hidden_dropout self.attention_dropout = attention_dropout self.layernorm_epsilon = layernorm_epsilon diff --git a/vllm/transformers_utils/configs/dbrx.py b/vllm/transformers_utils/configs/dbrx.py index 3a19af7129e73..0dc9664723d34 100644 --- a/vllm/transformers_utils/configs/dbrx.py +++ b/vllm/transformers_utils/configs/dbrx.py @@ -12,7 +12,7 @@ logger = logging.get_logger(__name__) -DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore class DbrxAttentionConfig(PretrainedConfig): @@ -72,9 +72,10 @@ def from_pretrained( and config_dict["model_type"] != cls.model_type ): logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all configurations of " + "models and can yield errors.", + config_dict["model_type"], cls.model_type) return cls.from_dict(config_dict, **kwargs) @@ -151,9 +152,9 @@ def from_pretrained( and config_dict["model_type"] != cls.model_type ): logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all " + "configurations of models and can yield errors.", config_dict["model_type"], cls.model_type) return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 94f438716f8bf..b06a946f34a47 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -222,13 +222,15 @@ def _alibi_scaling_validation(self): f"got {alibi_scaling_type}") if (alibi_scaling_factor is not None and not isinstance(alibi_scaling_factor, float) - or alibi_scaling_factor <= 1.0): + or (alibi_scaling_factor is not None + and alibi_scaling_factor <= 1.0)): raise ValueError( f"`alibi_scaling`'s factor field must be a float > 1.0," f"got {alibi_scaling_factor}") if (alibi_dynamic_scaling is not None and not isinstance(alibi_dynamic_scaling, int) - or alibi_dynamic_scaling <= 1): + or (alibi_dynamic_scaling is not None + and alibi_dynamic_scaling <= 1)): raise ValueError( f"`alibi_scaling`'s `train_seq_len` field must be an" f"integer > 1, got {alibi_dynamic_scaling}") diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 486c1938e1e10..f064c26c3f40c 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -87,12 +87,15 @@ def decode_prompt_logprobs_inplace( prev_tokens.extend(next_iter_tokens) def decode_sequence_inplace(self, seq: Sequence, - prms: SamplingParams) -> None: + prms: SamplingParams) -> int: """Decodes the new token for a sequence. In-place operation. Args: seq: The sequence to decode. prms: The sampling parameters used to generate the sequence. + + Returns: + The number of characters added to the output text. """ all_input_ids = seq.get_token_ids() token_id_generated_this_iteration = all_input_ids[-1] @@ -151,6 +154,8 @@ def decode_sequence_inplace(self, seq: Sequence, seq.read_offset = read_offset seq.output_text += new_decoded_token_text + return len(new_decoded_token_text) + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], @@ -163,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders( # NOTE(woosuk): The following code is slow because it runs a for loop over # the output_tokens. In Python, running a for loop over a list can be slow # even when the loop body is very simple. - sub_texts = [] - current_sub_text = [] + sub_texts: List[str] = [] + current_sub_text: List[str] = [] all_special_tokens = set(tokenizer.all_special_tokens) for token in output_tokens: if skip_special_tokens and token in all_special_tokens: @@ -258,6 +263,7 @@ def detokenize_incrementally( tokenizer, all_input_ids[:-1], skip_special_tokens=skip_special_tokens) + assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. if new_token_id >= len(tokenizer): @@ -266,6 +272,8 @@ def detokenize_incrementally( # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( [new_token_id], skip_special_tokens=skip_special_tokens) + if isinstance(new_tokens, str): + new_tokens = [new_tokens] output_tokens = prev_tokens + new_tokens # If this is the first iteration, return all tokens. diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index e216a99af91f9..f5684dbf1271c 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,11 +1,14 @@ +import os from typing import Optional, Union +import huggingface_hub from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) +from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizers import * +from vllm.transformers_utils.tokenizers import BaichuanTokenizer from vllm.utils import make_async logger = init_logger(__name__) @@ -28,7 +31,7 @@ def get_cached_tokenizer( tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) tokenizer_len = len(tokenizer) - class CachedTokenizer(tokenizer.__class__): + class CachedTokenizer(tokenizer.__class__): # type: ignore @property def all_special_ids(self): @@ -56,10 +59,29 @@ def get_tokenizer( *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, - tokenizer_revision: Optional[str] = None, + revision: Optional[str] = None, + download_dir: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - """Gets a tokenizer for the given model name via Huggingface.""" + """Gets a tokenizer for the given model name via HuggingFace or ModelScope. + """ + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # Only set the tokenizer here, model will be downloaded on the workers. + if not os.path.exists(tokenizer_name): + tokenizer_path = snapshot_download( + model_id=tokenizer_name, + cache_dir=download_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + # Ignore weights - we only need the tokenizer. + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + tokenizer_name = tokenizer_path + if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError( @@ -71,7 +93,7 @@ def get_tokenizer( tokenizer_name, *args, trust_remote_code=trust_remote_code, - tokenizer_revision=tokenizer_revision, + revision=revision, **kwargs) except ValueError as e: # If the error pertains to the tokenizer class not existing or not @@ -95,7 +117,7 @@ def get_tokenizer( tokenizer_name, *args, trust_remote_code=trust_remote_code, - tokenizer_revision=tokenizer_revision, + revision=revision, **kwargs) else: raise e @@ -118,9 +140,8 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, # No tokenizer was found in the LoRA folder, # use base model tokenizer logger.warning( - f"No tokenizer found in {lora_request.lora_local_path}, " - "using base model tokenizer instead. " - f"(Exception: {str(e)})") + "No tokenizer found in %s, using base model tokenizer instead. " + "(Exception: %s)", lora_request.lora_local_path, e) tokenizer = None return tokenizer diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index a3b979e8fbc13..0195c40c27f60 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,7 +1,7 @@ from typing import Optional from vllm.config import TokenizerPoolConfig -from vllm.engine.ray_utils import ray +from vllm.executor.ray_utils import ray from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( @@ -11,7 +11,7 @@ from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( RayTokenizerGroupPool) else: - RayTokenizerGroupPool = None + RayTokenizerGroupPool = None # type: ignore def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index c00b02fdbbbc0..7c605416854b8 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer from vllm.config import TokenizerPoolConfig -from vllm.engine.ray_utils import ray +from vllm.executor.ray_utils import ray from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) @@ -89,6 +89,7 @@ def encode(self, This is blocking. """ self._ensure_queue_initialized() + assert self._idle_actors is not None if self._idle_actors.empty(): raise RuntimeError("No idle actors available.") @@ -120,6 +121,7 @@ async def encode_async( This is non-blocking. """ self._ensure_queue_initialized() + assert self._idle_actors is not None actor = await self._idle_actors.get() try: diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 927cbeed073bf..9614f01d2b955 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -34,12 +34,26 @@ def get_max_input_len(self, """Get the maximum input length for the LoRA request.""" return self.max_input_length + def _raise_if_input_too_long(self, + encoded_tokens: List[str], + lora_request: Optional[LoRARequest] = None): + input_length = len(encoded_tokens) + if lora_request: + max_input_length = (lora_request.long_lora_max_len + or self.max_input_length) + else: + max_input_length = self.max_input_length + if max_input_length is not None and input_length > max_input_length: + raise ValueError("Input too long.", input_length, max_input_length) + def encode(self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - return tokenizer.encode(prompt) + ret = tokenizer.encode(prompt) + self._raise_if_input_too_long(ret, lora_request) + return ret async def encode_async( self, @@ -47,7 +61,9 @@ async def encode_async( request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - return tokenizer.encode(prompt) + ret = tokenizer.encode(prompt) + self._raise_if_input_too_long(ret, lora_request) + return ret def get_lora_tokenizer( self, diff --git a/vllm/transformers_utils/tokenizers/baichuan.py b/vllm/transformers_utils/tokenizers/baichuan.py index 02045bdcb2ccf..76daabc41e0a2 100644 --- a/vllm/transformers_utils/tokenizers/baichuan.py +++ b/vllm/transformers_utils/tokenizers/baichuan.py @@ -16,11 +16,11 @@ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} -PRETRAINED_VOCAB_FILES_MAP = { +PRETRAINED_VOCAB_FILES_MAP = { # type: ignore "vocab_file": {}, "tokenizer_file": {}, } -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} # type: ignore class BaichuanTokenizer(PreTrainedTokenizer): @@ -114,9 +114,9 @@ def _convert_id_to_token(self, index): token = self.sp_model.IdToPiece(index) return token - def convert_tokens_to_string(self, tokens): + def convert_tokens_to_string(self, tokens: List[str]): """Converts a sequence of tokens (string) in a single string.""" - current_sub_tokens = [] + current_sub_tokens: List[str] = [] out_string = "" prev_is_special = False for i, token in enumerate(tokens): @@ -148,9 +148,9 @@ def save_vocabulary(self, `Tuple(str)`: Paths to the files saved. """ if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) " - "should be a directory") - return + raise ValueError(f"Vocabulary path ({save_directory}) " + "should be a directory") + out_vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 658fe5c98f5ee..40a954a29493e 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -7,7 +7,7 @@ from enum import Enum from pathlib import Path from threading import Thread -from typing import Dict, Optional +from typing import Any, Dict, Optional from uuid import uuid4 import cpuinfo @@ -15,20 +15,22 @@ import requests import torch -_config_home = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) +import vllm.envs as envs + +_config_home = envs.VLLM_CONFIG_ROOT _USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json") _USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home, "vllm/do_not_track") _USAGE_STATS_ENABLED = None -_USAGE_STATS_SERVER = os.environ.get("VLLM_USAGE_STATS_SERVER", - "https://stats.vllm.ai") +_USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER def is_usage_stats_enabled(): """Determine whether or not we can send usage stats to the server. The logic is as follows: - By default, it should be enabled. - - Two environment variables can disable it: + - Three environment variables can disable it: + - VLLM_DO_NOT_TRACK=1 - DO_NOT_TRACK=1 - VLLM_NO_USAGE_STATS=1 - A file in the home directory can disable it if it exists: @@ -36,8 +38,8 @@ def is_usage_stats_enabled(): """ global _USAGE_STATS_ENABLED if _USAGE_STATS_ENABLED is None: - do_not_track = os.environ.get("DO_NOT_TRACK", "0") == "1" - no_usage_stats = os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1" + do_not_track = envs.VLLM_DO_NOT_TRACK + no_usage_stats = envs.VLLM_NO_USAGE_STATS do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH) _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats @@ -88,6 +90,7 @@ class UsageContext(str, Enum): LLM_CLASS = "LLM_CLASS" API_SERVER = "API_SERVER" OPENAI_API_SERVER = "OPENAI_API_SERVER" + OPENAI_BATCH_RUNNER = "OPENAI_BATCH_RUNNER" ENGINE_CONTEXT = "ENGINE_CONTEXT" @@ -124,7 +127,7 @@ def __init__(self) -> None: def report_usage(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any] = None) -> None: + extra_kvs: Optional[Dict[str, Any]] = None) -> None: t = Thread(target=self._report_usage_worker, args=(model_architecture, usage_context, extra_kvs or {}), daemon=True) @@ -132,13 +135,13 @@ def report_usage(self, def _report_usage_worker(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any]) -> None: + extra_kvs: Dict[str, Any]) -> None: self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_continous_usage() def _report_usage_once(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any]) -> None: + extra_kvs: Dict[str, Any]) -> None: # Platform information if torch.cuda.is_available(): device_property = torch.cuda.get_device_properties(0) @@ -167,7 +170,7 @@ def _report_usage_once(self, model_architecture: str, # Metadata self.log_time = _get_current_timestamp_ns() - self.source = os.environ.get("VLLM_USAGE_SOURCE", "production") + self.source = envs.VLLM_USAGE_SOURCE data = vars(self) if extra_kvs: diff --git a/vllm/utils.py b/vllm/utils.py index 8ba03333d3b6c..2781eceb7ba98 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,22 +1,28 @@ import asyncio +import datetime import enum import gc import os import socket import subprocess +import sys +import tempfile +import threading import uuid import warnings -from collections import OrderedDict, defaultdict -from functools import lru_cache, partial +from collections import defaultdict +from functools import lru_cache, partial, wraps from platform import uname -from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List, - Optional, Tuple, TypeVar, Union) +from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, + Hashable, List, Optional, OrderedDict, Tuple, TypeVar, + Union) +import numpy as np import psutil import torch -from packaging.version import Version, parse -from vllm.logger import init_logger +import vllm.envs as envs +from vllm.logger import enable_trace_function_call, init_logger T = TypeVar("T") logger = init_logger(__name__) @@ -26,6 +32,8 @@ "bfloat16": torch.bfloat16, "float": torch.float, "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, } @@ -51,7 +59,7 @@ def reset(self) -> None: class LRUCache(Generic[T]): def __init__(self, capacity: int): - self.cache = OrderedDict[Hashable, T]() + self.cache: OrderedDict[Hashable, T] = OrderedDict() self.capacity = capacity def __contains__(self, key: Hashable) -> bool: @@ -60,7 +68,7 @@ def __contains__(self, key: Hashable) -> bool: def __len__(self) -> int: return len(self.cache) - def __getitem__(self, key: Hashable) -> T: + def __getitem__(self, key: Hashable) -> Optional[T]: return self.get(key) def __setitem__(self, key: Hashable, value: T) -> None: @@ -76,7 +84,7 @@ def get(self, key: Hashable, default_value: Optional[T] = None) -> Optional[T]: if key in self.cache: - value = self.cache[key] + value: Optional[T] = self.cache[key] self.cache.move_to_end(key) else: value = default_value @@ -87,7 +95,7 @@ def put(self, key: Hashable, value: T) -> None: self.cache.move_to_end(key) self._remove_old_if_needed() - def _on_remove(self, key: Hashable, value: T): + def _on_remove(self, key: Hashable, value: Optional[T]): pass def remove_oldest(self): @@ -100,9 +108,11 @@ def _remove_old_if_needed(self) -> None: while len(self.cache) > self.capacity: self.remove_oldest() - def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T: + def pop(self, + key: Hashable, + default_value: Optional[T] = None) -> Optional[T]: run_on_remove = key in self.cache - value = self.cache.pop(key, default_value) + value: Optional[T] = self.cache.pop(key, default_value) if run_on_remove: self._on_remove(key, value) return value @@ -159,6 +169,17 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) +@lru_cache(maxsize=None) +def get_vllm_instance_id(): + """ + If the environment variable VLLM_INSTANCE_ID is set, return it. + Otherwise, return a random UUID. + Instance id represents an instance of the VLLM. All processes in the same + instance should have the same instance id. + """ + return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}" + + @lru_cache(maxsize=None) def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 @@ -181,8 +202,53 @@ def _async_wrapper(*args, **kwargs) -> asyncio.Future: return _async_wrapper +def merge_async_iterators( + *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() + + finished = [False] * len(iterators) + + async def producer(i: int, iterator: AsyncIterator[T]): + try: + async for item in iterator: + await queue.put((i, item)) + except Exception as e: + await queue.put(e) + finished[i] = True + + _tasks = [ + asyncio.create_task(producer(i, iterator)) + for i, iterator in enumerate(iterators) + ] + + async def consumer(): + try: + while not all(finished) or not queue.empty(): + item = await queue.get() + if isinstance(item, Exception): + raise item + yield item + except (Exception, asyncio.CancelledError) as e: + for task in _tasks: + if sys.version_info >= (3, 9): + # msg parameter only supported in Python 3.9+ + task.cancel(e) + else: + task.cancel() + raise e + await asyncio.gather(*_tasks) + + return consumer() + + def get_ip() -> str: - host_ip = os.environ.get("HOST_IP") + host_ip = envs.VLLM_HOST_IP if host_ip: return host_ip @@ -208,7 +274,8 @@ def get_ip() -> str: warnings.warn( "Failed to get the IP address, using 0.0.0.0 by default." - "The value can be set by the environment variable HOST_IP.", + "The value can be set by the environment variable" + " VLLM_HOST_IP or HOST_IP.", stacklevel=2) return "0.0.0.0" @@ -220,6 +287,9 @@ def get_distributed_init_method(ip: str, port: int) -> str: def get_open_port() -> int: + port = envs.VLLM_PORT + if port is not None: + return port # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -232,8 +302,13 @@ def get_open_port() -> int: return s.getsockname()[1] -def set_cuda_visible_devices(device_ids: List[int]) -> None: - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) +def update_environment_variables(envs: Dict[str, str]): + for k, v in envs.items(): + if k in os.environ and os.environ[k] != v: + logger.warning( + "Overwriting environment variable %s " + "from '%s' to '%s'", k, os.environ[k], v) + os.environ[k] = v def chunk_list(lst, chunk_size): @@ -246,26 +321,6 @@ def cdiv(a: int, b: int) -> int: return -(a // -b) -@lru_cache(maxsize=None) -def get_nvcc_cuda_version() -> Optional[Version]: - cuda_home = os.environ.get('CUDA_HOME') - if not cuda_home: - cuda_home = '/usr/local/cuda' - if os.path.isfile(cuda_home + '/bin/nvcc'): - logger.info(f'CUDA_HOME is not found in the environment. ' - f'Using {cuda_home} as CUDA_HOME.') - else: - logger.warning( - f'Not found nvcc in {cuda_home}. Skip cuda version check!') - return None - nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], - universal_newlines=True) - output = nvcc_output.split() - release_idx = output.index("release") + 1 - nvcc_cuda_version = parse(output[release_idx].split(",")[0]) - return nvcc_cuda_version - - def _generate_random_fp8( tensor: torch.tensor, low: float, @@ -279,28 +334,16 @@ def _generate_random_fp8( #-----|-------------|------------------- # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} - from vllm._C import cache_ops + from vllm import _custom_ops as ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) - cache_ops.convert_fp8(tensor_tmp, tensor) + ops.convert_fp8(tensor, tensor_tmp) del tensor_tmp -def create_kv_caches_with_random( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: Optional[int] = 0, - device: Optional[str] = "cuda", -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - +def get_kv_cache_torch_dtype( + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": if isinstance(model_dtype, str): @@ -319,6 +362,55 @@ def create_kv_caches_with_random( torch_dtype = cache_dtype else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + assert cache_dtype != "fp8" + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + scale = head_size**-0.5 + key_caches, value_caches = [], [] + for _ in range(num_layers): + key_value_cache = torch.empty(size=key_value_cache_shape, + dtype=torch_dtype, + device=device) + key_value_cache.uniform_(-scale, scale) + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() @@ -372,7 +464,6 @@ def is_pin_memory_available() -> bool: print_warning_once("Pin memory is not supported on Neuron.") return False elif is_cpu(): - print_warning_once("Pin memory is not supported on CPU.") return False return True @@ -401,7 +492,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): gc.collect() -def str_to_int_tuple(s: str) -> Tuple[int]: +def str_to_int_tuple(s: str) -> Tuple[int, ...]: """Convert a string to a tuple of integers.""" try: return tuple(map(int, s.split(","))) @@ -411,11 +502,6 @@ def str_to_int_tuple(s: str) -> Tuple[int]: f"(e.g., 1, 2, 3). Given input: {s}") from e -def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]: - assert len(x) <= max_len - return x + [pad] * (max_len - len(x)) - - def make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -428,7 +514,10 @@ def make_tensor_with_pad( The padding is applied to the end of each inner list until it reaches `max_len`. """ - padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x] + padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, :len(blocktb)] = blocktb return torch.tensor(padded_x, dtype=dtype, device=device) @@ -455,7 +544,7 @@ def maybe_expand_dim(tensor: torch.Tensor, def merge_dicts(dict1: Dict[Any, List[Any]], dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]: """Merge 2 dicts that have key -> List of items. - + When a key conflicts, the values in dict1 is prioritized. """ merged_dict = defaultdict(list) @@ -467,3 +556,121 @@ def merge_dicts(dict1: Dict[Any, List[Any]], merged_dict[key].extend(value) return dict(merged_dict) + + +def init_cached_hf_modules(): + """ + Lazy initialization of the Hugging Face modules. + """ + from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() + + +@lru_cache(maxsize=None) +def find_library(lib_name: str) -> str: + """ + Find the library file in the system. + `lib_name` is full filename, with both prefix and suffix. + This function resolves `lib_name` to the full path of the library. + """ + # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa + # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard + # `/sbin/ldconfig` should exist in all Linux systems. + # `/sbin/ldconfig` searches the library in the system + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() + # each line looks like the following: + # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 + locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line] + # `LD_LIBRARY_PATH` searches the library in the user-defined paths + env_ld_library_path = envs.LD_LIBRARY_PATH + if not locs and env_ld_library_path: + locs = [ + os.path.join(dir, lib_name) + for dir in env_ld_library_path.split(":") + if os.path.exists(os.path.join(dir, lib_name)) + ] + if not locs: + raise ValueError(f"Cannot find {lib_name} in the system.") + return locs[0] + + +def find_nccl_library(): + """ + We either use the library file specified by the `VLLM_NCCL_SO_PATH` + environment variable, or we find the library file brought by PyTorch. + After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be + found by `ctypes` automatically. + """ + so_file = envs.VLLM_NCCL_SO_PATH + + # manually load the nccl library + if so_file: + logger.info( + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", + so_file) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.info("Found nccl from library %s", so_file) + return so_file + + +def enable_trace_function_call_for_thread() -> None: + """Set up function tracing for the current thread, + if enabled via the VLLM_TRACE_FUNCTION environment variable + """ + + if envs.VLLM_TRACE_FUNCTION: + tmp_dir = tempfile.gettempdir() + filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log").replace(" ", "_") + log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + filename) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + enable_trace_function_call(log_path) + + +def identity(value: T) -> T: + return value + + +F = TypeVar('F', bound=Callable[..., Any]) + + +def deprecate_kwargs( + *kws: str, + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None) -> Callable[[F], F]: + deprecated_kws = set(kws) + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index c34ee0648626b..2f0e59f7ae7c9 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -1,5 +1,5 @@ """CacheEngine class for managing the KV cache.""" -from typing import Dict, List +from typing import List import torch @@ -31,7 +31,7 @@ def __init__( self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) - self.num_heads = model_config.get_num_kv_heads(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks @@ -43,7 +43,15 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend( + model_config.get_num_attention_heads(parallel_config), + self.head_size, + self.num_kv_heads, + model_config.get_sliding_window(), + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") @@ -56,28 +64,31 @@ def _allocate_kv_cache( ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_heads, self.head_size) + num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. kv_cache.append( - torch.empty(kv_cache_shape, + torch.zeros(kv_cache_shape, dtype=self.dtype, pin_memory=pin_memory, device=device)) return kv_cache - def swap_in(self, src_to_dst: Dict[int, int]) -> None: + def swap_in(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_layers): self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], src_to_dst) - def swap_out(self, src_to_dst: Dict[int, int]) -> None: + def swap_out(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_layers): self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], src_to_dst) - def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + def copy(self, src_to_dsts: torch.Tensor) -> None: self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) @staticmethod diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py new file mode 100644 index 0000000000000..bc88f2c5bed6c --- /dev/null +++ b/vllm/worker/cpu_model_runner.py @@ -0,0 +1,346 @@ +from typing import List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader import get_model +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 + + +class CPUModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + # Currently, CPU worker doesn't support chunked prefill. + assert self.scheduler_config.chunked_prefill_enabled is False + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.vision_language_config = vision_language_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + + def load_model(self) -> None: + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + vision_language_config=self.vision_language_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + Optional[torch.Tensor]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + multi_modal_input_list: List[torch.Tensor] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + seq_len = len(prompt_tokens) + + seq_lens.append(seq_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(list(range(computed_len, seq_len))) + + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seq_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, seq_len - self.sliding_window) + + for i in range(computed_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + seq_lens=seq_lens, + seq_lens_tensor=None, + max_decode_seq_len=None, + num_prefills=len(seq_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + block_tables=torch.tensor([]), + slot_mapping=slot_mapping, + ) + return (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_input) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_decode_seq_len = max(seq_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_decode_seq_len=max_decode_seq_len, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + num_prefills=0, + block_tables=block_tables, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, + Optional[torch.Tensor]]: + multi_modal_input = None + if self.is_driver_worker: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_input + ) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode(seq_group_metadata_list) + seq_lens = [] + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # query_lens is not needed if chunked prefill is not + # supported. Since CPU worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens, + self.device, + pin_memory=False) + # Broadcast the metadata. + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": + sampling_metadata.selected_token_indices, + } + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + selected_token_indices = metadata_dict.pop( + "selected_token_indices") + attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + seq_lens=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + generators=None, + ) + + return (input_tokens, input_positions, attn_metadata, + sampling_metadata, multi_modal_input) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[torch.Tensor], + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, attn_metadata, sampling_metadata, + multi_modal_input + ) = self.prepare_input_tensors(seq_group_metadata_list) + + model_executable = self.model + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return None + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return output diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 42f0828b826e2..3ee394f9912e9 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,37 +1,26 @@ """A CPU worker class.""" -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import torch import torch.distributed from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.worker.model_runner import ModelRunner +from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase logger = init_logger(__name__) -class CPUModelRunner(ModelRunner): - - def load_model(self) -> None: - self.model = get_model(self.model_config, - self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) - - class CPUCacheEngine: """Manages the KV cache for CPU backend. @@ -64,7 +53,15 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) @@ -129,10 +126,12 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, cache_config: CacheConfig, + load_config: LoadConfig, local_rank: int, rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, + vision_language_config: Optional[VisionLanguageConfig] = None, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: @@ -141,25 +140,35 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.cache_config = cache_config + self.load_config = load_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.vision_language_config = vision_language_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." - self.model_runner = CPUModelRunner(model_config, - parallel_config, - scheduler_config, - device_config, - lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + self.model_runner = CPUModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=self.load_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine = None - self.cpu_cache = None + self.cache_engine: CPUCacheEngine + self.cpu_cache: List[torch.Tensor] def init_device(self) -> None: self.init_distributed_environment() @@ -169,7 +178,7 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of blocks available for the KV cache. This determines how many KV blocks can fit into the configured CPU @@ -248,30 +257,35 @@ def _init_cache_engine(self) -> None: def cache_copy( self, - blocks_to_copy: Dict[int, List[int]], + blocks_to_copy: torch.Tensor, ) -> None: - if blocks_to_copy: + if blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, - blocks_to_swap_in: Optional[Dict[int, int]] = None, - blocks_to_swap_out: Optional[Dict[int, int]] = None, - blocks_to_copy: Optional[Dict[int, List[int]]] = None, - ) -> Optional[SamplerOutput]: + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> List[SamplerOutput]: + + if execute_model_req is None: + seq_group_metadata_list = None + else: + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + if self.is_driver_worker: assert seq_group_metadata_list is not None - num_seq_groups = len(seq_group_metadata_list) - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None - assert len(blocks_to_swap_in) == 0 - assert len(blocks_to_swap_out) == 0 - data = { + num_seq_groups: int = len(seq_group_metadata_list) + assert execute_model_req is not None + blocks_to_copy = execute_model_req.blocks_to_copy + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device="cpu", + dtype=torch.int64).view(-1, 2) + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 + data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, - "blocks_to_copy": blocks_to_copy, + "blocks_to_copy": execute_model_req.blocks_to_copy, } broadcast_tensor_dict(data, src=0) else: @@ -283,11 +297,13 @@ def execute_model( # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list, self.cpu_cache) - return output + + # CPU worker only supports single-step execution. + return [output] def init_distributed_environment(self) -> None: """Initialize the distributed environment.""" diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py new file mode 100644 index 0000000000000..0ba1200696cab --- /dev/null +++ b/vllm/worker/embedding_model_runner.py @@ -0,0 +1,170 @@ +from typing import Dict, List, Optional, Set, Tuple + +import torch + +from vllm.attention import AttentionMetadata +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.pooling_params import PoolingParams +from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata +from vllm.worker.model_runner import ModelRunner + +logger = init_logger(__name__) + + +class EmbeddingModelRunner(ModelRunner): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + vision_language_config: Optional[VisionLanguageConfig] = None, + ): + super().__init__(model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + vision_language_config=vision_language_config) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[torch.Tensor], + ) -> Optional[PoolerOutput]: + (input_tokens, input_positions, attn_metadata, pooling_metadata, + lora_requests, lora_mapping, multi_modal_input + ) = self.prepare_input_tensors(seq_group_metadata_list) + + if self.lora_config: + self.set_active_loras(lora_requests, lora_mapping) + + # Currently cuda graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) + hidden_states = model_executable(**execute_model_kwargs) + + # Only perform pooling in the driver worker. + if not self.is_driver_worker: + return None + + return self.model.pooler(hidden_states=hidden_states, + pooling_metadata=pooling_metadata) + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, + Set[LoRARequest], LoRAMapping, torch.Tensor]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + # Prepare input tensors. + ( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + _, + lora_mapping, + lora_requests, + multi_modal_input, + slot_mapping, + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) + # Prepare PoolingMetadata + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + seq_lens) + + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, + "multi_modal_input": multi_modal_input, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + } + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + lora_mapping = metadata_dict.pop("lora_mapping") + lora_requests = metadata_dict.pop("lora_requests") + multi_modal_input = metadata_dict.pop("multi_modal_input") + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + attn_metadata = None + pooling_metadata = PoolingMetadata(seq_groups=None, + seq_data=None, + prompt_lens=None) + + return (input_tokens, input_positions, attn_metadata, pooling_metadata, + lora_requests, lora_mapping, multi_modal_input) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e7f20475ab1a7..47aa70dc617af 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,31 +1,28 @@ -import contextlib import time -from typing import Dict, List, Optional, Set, Tuple +import warnings +from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np import torch import torch.nn as nn from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.distributed.communication_op import graph_capture from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.parallel_state import ( - with_pynccl_for_all_reduce) -from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip, - is_pin_memory_available, make_tensor_with_pad, - maybe_expand_dim) +from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, + is_pin_memory_available, make_tensor_with_pad) logger = init_logger(__name__) @@ -39,6 +36,38 @@ ] +class ModelInput(NamedTuple): + input_tokens: torch.Tensor + input_positions: torch.Tensor + attn_metadata: Optional[AttentionMetadata] + seq_lens: List[int] + query_lens: List[int] + lora_mapping: Optional[LoRAMapping] + lora_requests: Set[LoRARequest] + multi_modal_input: Optional[torch.Tensor] + slot_mapping: torch.Tensor + num_prefill_tokens: int + num_decode_tokens: int + num_prefills: int + + @classmethod + def empty(cls, device): + return ModelInput( + input_tokens=torch.empty(0, device=device), + input_positions=torch.empty(0, device=device), + attn_metadata=None, + seq_lens=[], + query_lens=[], + lora_mapping=None, + lora_requests=set(), + multi_modal_input=None, + slot_mapping=torch.empty(0, device=device), + num_prefill_tokens=0, + num_decode_tokens=0, + num_prefills=0, + ) + + class ModelRunner: def __init__( @@ -47,6 +76,8 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, @@ -55,54 +86,65 @@ def __init__( self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config self.lora_config = lora_config + self.load_config = load_config self.is_driver_worker = is_driver_worker + self.vision_language_config = vision_language_config - # model_config can be None in tests/samplers/test_sampler.py. - # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() - self.model = None - self.block_size = None # Set after initial profiling. - self.lora_manager = None - + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture self.graph_runners: Dict[int, CUDAGraphRunner] = {} - self.graph_memory_pool = None # Set during graph capture. - - self.max_context_len_to_capture = ( - self.model_config.max_context_len_to_capture - if self.model_config is not None else 0) + self.graph_memory_pool: Optional[Tuple[ + int, int]] = None # Set during graph capture. # When using CUDA graph, the input block tables must be padded to - # max_context_len_to_capture. However, creating the block table in + # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables = None # Set after initial profiling. - self.pin_memory = is_pin_memory_available() - self.kv_cache_dtype = kv_cache_dtype - self.vision_language_config = vision_language_config - + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) + + # Lazy initialization + self.model: nn.Module # Set after load_model + # Set if the backend is flashinfer. + self.flashinfer_workspace_buffer: torch.Tensor + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( - self.model_config, - self.device_config, + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, lora_config=self.lora_config, vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + scheduler_config=self.scheduler_config, + cache_config=self.cache_config, + ) self.model_memory_usage = m.consumed_memory - logger.info(f"Loading model weights took " - f"{self.model_memory_usage / float(2**30):.4f} GB") + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) if self.lora_config: assert hasattr(self.model, "supported_lora_modules" @@ -115,305 +157,321 @@ def load_model(self) -> None: ), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, self.vocab_size, - self.lora_config, self.device, self.model.embedding_modules, - self.model.embedding_padding_modules) + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=self.model.config. + max_position_embeddings, + ) self.model = self.lora_manager.create_lora_manager(self.model) if self.kv_cache_dtype == "fp8" and is_hip(): - # Currently scaled KV cache is only enabled on ROCm + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. if self.model_config.quantization_param_path is not None: if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is " + "deprecated and will be removed. Please include " + "kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2) self.model.load_kv_cache_scales( self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", + self.model_config.quantization_param_path) else: - raise RuntimeError("Using FP8 KV cache and scaling " - "factors provided but model " - f"{self.model.__class__} does not " - "support loading scaling factors.") + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but " + "model %s does not support loading scaling factors.", + self.model.__class__) else: - logger.warn("Using FP8 KV cache but no scaling factors " - "provided. Defaulting to scaling factors of 1.0. " - "This may lead to less accurate results!") - elif self.model_config.quantization_param_path is not None: - logger.warn("KV cache scaling factors provided, " - "but the KV cache data type is not FP8. " - "KV cache scaling factors will not be used.") - - def set_block_size(self, block_size: int) -> None: - self.block_size = block_size + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!") - self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), - dtype=np.int32) + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from vllm.model_executor.model_loader.loader import ShardedStateLoader + ShardedStateLoader.save_model( + self.model, + path, + pattern=pattern, + max_size=max_size, + ) def get_max_block_per_batch(self) -> int: block_size = self.block_size - return (self.max_context_len_to_capture + block_size - 1) // block_size + return (self.max_seq_len_to_capture + block_size - 1) // block_size - def _prepare_prompt( + def _prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], List[int], List[int], Set[LoRARequest], - torch.Tensor]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() - - prompt_lens: List[int] = [] - context_lens: List[int] = [] - subquery_lens: List[int] = [] - prefix_block_tables: List[List[int]] = [] - multi_modal_input_list: List[torch.Tensor] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and computed_block_nums is not None): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - token_chunk_size = seq_group_metadata.token_chunk_size - seq_data = seq_group_metadata.seq_data[seq_id] - computed_len = seq_data.get_num_computed_tokens() - # We should use get_len here because in case of preemption - # it contains output tokens. - prefill_end = min(seq_data.get_len(), - computed_len + token_chunk_size) - # TODO(sang): Rename it after chunked prefill is introduced. - prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - prompt_len = len(prompt_tokens) - # Right now, the prefill_end is always same as the length of - # sequence. However, once chunked prefill is introduced, this - # assumption can be changed. - assert prefill_end == seq_data.get_len() - prompt_lens.append(prompt_len) - - # NOTE: This only works for oooooooxxx style attention. - if computed_block_nums is not None and len( - computed_block_nums) > 0 and self.sliding_window is None: - # Prefix is not supported with sliding_window - computed_len = len(computed_block_nums) * self.block_size - prompt_tokens = prompt_tokens[computed_len:] - prefix_block_tables.append(computed_block_nums) - else: - prefix_block_tables.append([]) - # Right now, prefill start is always 0. However, this - # assumption can be changed once chunked prefill is introduced. - assert computed_len == 0 - - # actual prompt lens - context_lens.append(computed_len) - subquery_lens.append(prompt_len - computed_len) - - input_tokens.extend(prompt_tokens) - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, prefill_end))) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * (prompt_len - computed_len) - lora_prompt_mapping.extend( - [lora_id] * - (prompt_len - computed_len - if seq_group_metadata.sampling_params.prompt_logprobs else 1)) - - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) - - if seq_group_metadata.block_tables is None: - # During memory profiling, the block tables are not initialized - # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, prompt_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - assert computed_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention") - start_idx = max(0, prompt_len - self.sliding_window) - - for i in range(computed_len, prefill_end): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - max_subquery_len = max(subquery_lens) - max_prompt_len = max(prompt_lens) - num_prompt_tokens = len(input_tokens) - assert max_subquery_len > 0 - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - lora_index_mapping = lora_index_mapping - - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None - - # Prepare prefix block tables - max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) - - # Query length can be shorter than key (i.e., prompt) when prefill - # is chunked or prefix cached. - subquery_lens_tensor = torch.tensor(subquery_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - prompt_lens_tensor = torch.tensor(prompt_lens, - dtype=torch.long, - device=self.device) - seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) + ) -> ModelInput: + """Prepare the model input based on a given sequence group. - torch.cumsum(subquery_lens_tensor, - dim=0, - dtype=subquery_start_loc.dtype, - out=subquery_start_loc[1:]) + The API assumes seq_group_metadata_list is sorted by prefill -> decode. - torch.cumsum(prompt_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) + The result tensors and data structure also batches input in prefill + -> decode order. For example, - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - slot_mapping=slot_mapping, - prompt_lens=prompt_lens, - prompt_lens_tensor=prompt_lens_tensor, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=0, - max_subquery_len=max_subquery_len, - max_context_len=None, - max_prompt_len=max_prompt_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - kv_cache_dtype=self.kv_cache_dtype, - ) - return (input_tokens, input_positions, attn_metadata, prompt_lens, - subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input) + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], Set[LoRARequest]]: - assert len(seq_group_metadata_list) > 0 + If cuda graph is required, this API automatically pads inputs. + """ input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - context_lens: List[int] = [] - block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 + seq_lens: List[int] = [] + prefill_seq_lens: List[int] = [] + decode_seq_lens: List[int] = [] + context_lens: List[int] = [] + query_lens: List[int] = [] + block_tables: List[List[int]] = [] + multi_modal_input_list: List[torch.Tensor] = [] + decode_only = True + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = 0 + + # The following fields are only for flashinfer + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + paged_kv_last_page_len: List[int] = [] + + if len(seq_group_metadata_list) == 0: + return ModelInput.empty(self.device) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window + self.block_size - + 1) // self.block_size + block_aligned_sliding_window = \ + sliding_window_blocks * self.block_size + for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) + is_prompt = seq_group_metadata.is_prompt for seq_id in seq_ids: + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() - 1 + + seq_len = min( + seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append(position) + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. + curr_sliding_window_blocks = None + sliding_seq_len = seq_len + sliding_context_len = context_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if (self.sliding_window is not None and not is_prompt): + curr_sliding_window_blocks = sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = seq_len % self.block_size + sliding_seq_len = min( + seq_len, block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_blocks += 1 + else: + sliding_seq_len = min(seq_len, self.sliding_window) + sliding_context_len = sliding_seq_len - 1 + + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + + # need to think what to set it to when we have both sliding + # window and prefix caching... + assert self.sliding_window is None, \ + "Prefix caching is not supported with sliding window" + sliding_context_len = context_len + + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + block_table = block_table[ + -curr_sliding_window_blocks:] + if self.attn_backend.get_name() == "flashinfer": + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(paged_kv_indptr[-1] + + len(block_table)) + last_page_len = seq_data.get_len( + ) % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + block_tables.append(block_table) - context_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) - context_lens.append(context_len) + seq_lens.append(sliding_seq_len) + context_lens.append(sliding_context_len) + query_len = sliding_seq_len - sliding_context_len + query_lens.append(query_len) + input_tokens.extend(tokens) + input_positions.extend(list(range(context_len, seq_len))) + lora_id = seq_group_metadata.lora_int_id + + if is_prompt: + assert len(seq_ids) == 1 + num_prefills += 1 + num_prefill_tokens += len(tokens) + decode_only = False + prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + num_decode_tokens += query_len + decode_seq_lens.append(sliding_seq_len) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * query_len + lora_prompt_mapping.extend( + [lora_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + + if _is_block_tables_empty(seq_group_metadata.block_tables): + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - lora_index_mapping.append(lora_id) - lora_prompt_mapping.append(lora_id) + # Mask the [0, start_idx) tokens of the prompt with + # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) + if is_prompt: + assert self.scheduler_config.use_v2_block_manager \ + or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # It is an optimization. When it is decoding, it is always + # 0. When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) - # vLLM uses cuda graph only for decoding requests. - # See `capture_model` API for more details. - # For decoding requests, batch_size == input_tokens. batch_size = len(input_tokens) - max_context_len = max(context_lens) + max_query_len = max(query_lens) + max_prefill_seq_len = max(prefill_seq_lens, default=0) + max_decode_seq_len = max(decode_seq_lens, default=0) + + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. use_captured_graph = ( - not self.model_config.enforce_eager + decode_only and not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_context_len <= self.max_context_len_to_capture) + and max_decode_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -421,31 +479,13 @@ def _prepare_decode( input_tokens.append(0) input_positions.append(0) slot_mapping.append(_PAD_SLOT_ID) - context_lens.append(1) + seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) batch_size = graph_batch_size - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + num_decode_tokens = batch_size if use_captured_graph: - # When using cuda-graph all these tensors should be - # padded. - assert context_lens.shape[0] == input_tokens.shape[0] - assert context_lens.shape[0] == input_positions.shape[0] - assert context_lens.shape[0] == slot_mapping.shape[0] - # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.graph_block_tables[:batch_size] @@ -463,160 +503,158 @@ def _prepare_decode( dtype=torch.int, device=self.device, ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - slot_mapping=slot_mapping, - prompt_lens=None, - prompt_lens_tensor=None, - num_prompt_tokens=0, - num_generation_tokens=len(input_tokens), - max_subquery_len=None, - max_context_len=max_context_len, - max_prompt_len=None, - subquery_start_loc=None, - seq_start_loc=None, - context_lens=context_lens, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - kv_cache_dtype=self.kv_cache_dtype, - ) - return (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, lora_requests) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) - def _prepare_sample( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], - ) -> SamplingMetadata: - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - selected_token_indices: List[int] = [] - generators: List[torch.Generator] = [] - selected_token_start_idx = 0 - categorized_sample_indices = {t: [] for t in SamplingType} - categorized_sample_indices_start_idx = 0 - categorized_sampled_token_indices_start_idx = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - seq_groups.append((seq_ids, sampling_params)) - - if seq_group_metadata.is_prompt: - assert len(seq_ids) == 1 - assert subquery_lens is not None - subquery_len = subquery_lens[i] - if sampling_params.prompt_logprobs is not None: - # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += subquery_len - 1 - - categorized_sample_indices[ - sampling_params.sampling_type].append([ - categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx - ]) - categorized_sample_indices_start_idx += 1 - categorized_sampled_token_indices_start_idx += 1 - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + subquery_len - 1)) - selected_token_indices.append(selected_token_start_idx + - subquery_len - 1) - selected_token_start_idx += subquery_len - - if sampling_params.seed is not None: - seq_group_metadata.state.generator = torch.Generator( - device=self.device).manual_seed(sampling_params.seed) - else: - num_seqs = len(seq_ids) - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) - selected_token_start_idx += num_seqs - - categorized_sample_indices[ - sampling_params.sampling_type].extend( - zip( - range( - categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + - num_seqs), - range( - categorized_sampled_token_indices_start_idx, - categorized_sampled_token_indices_start_idx + - num_seqs))) - categorized_sample_indices_start_idx += num_seqs - categorized_sampled_token_indices_start_idx += num_seqs - - if sampling_params.seed is not None: - generators.append(seq_group_metadata.state.generator) - - selected_token_indices = async_tensor_h2d(selected_token_indices, - dtype=torch.long, - target_device=self.device, - pin_memory=self.pin_memory) - - categorized_sample_indices = { - t: maybe_expand_dim( - async_tensor_h2d(seq_ids, - dtype=torch.int, - target_device=self.device, - pin_memory=self.pin_memory), 2, 2) - for t, seq_ids in categorized_sample_indices.items() - } + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - generators=generators, + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + + if self.attn_backend.get_name() == "flashinfer": + if not hasattr(self, "flashinfer_workspace_buffer"): + # Allocate 16MB workspace buffer + # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html + self.flashinfer_workspace_buffer = torch.empty( + 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) + paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, + dtype=torch.int, + device=self.device) + paged_kv_indices_tensor = torch.tensor(paged_kv_indices, + dtype=torch.int, + device=self.device) + paged_kv_last_page_len_tensor = torch.tensor( + paged_kv_last_page_len, dtype=torch.int, device=self.device) + kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, + self.model_config.dtype) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + use_cuda_graph=False, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + workspace_buffer=self.flashinfer_workspace_buffer, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + num_qo_heads=self.model_config.get_num_attention_heads( + self.parallel_config), + num_kv_heads=self.model_config.get_num_kv_heads( + self.parallel_config), + head_dim=self.model_config.get_head_size(), + page_size=16, + seq_start_loc=seq_start_loc, + data_type=kv_cache_dtype) + else: + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + if self.lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + return ModelInput( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, + lora_requests=lora_requests, + multi_modal_input=multi_modal_input, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, ) - return sampling_metadata def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[int], LoRAMapping, torch.Tensor]: + Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt + assert seq_group_metadata_list is not None # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, prompt_lens, - subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, - lora_requests) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] - subquery_lens = None - multi_modal_input = None - sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None + ( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + query_lens, + lora_mapping, + lora_requests, + multi_modal_input, + slot_mapping, + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.pin_memory) - # Broadcast the metadata. metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -625,8 +663,13 @@ def prepare_input_tensors( "lora_requests": lora_requests, "lora_mapping": lora_mapping, "multi_modal_input": multi_modal_input, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, } - metadata_dict.update(attn_metadata.asdict_zerocopy()) + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) @@ -637,15 +680,16 @@ def prepare_input_tensors( lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + attn_metadata = None sampling_metadata = SamplingMetadata( seq_groups=None, - seq_data=None, - prompt_lens=None, selected_token_indices=selected_token_indices, categorized_sample_indices=None, - generators=None, - perform_sampling=False, + num_prompts=0, ) return (input_tokens, input_positions, attn_metadata, @@ -665,8 +709,10 @@ def execute_model( if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) - # Execute the model. - if attn_metadata.use_cuda_graph: + # Currently cuda graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -685,7 +731,7 @@ def execute_model( logits = self.model.compute_logits(hidden_states, sampling_metadata) # Only perform sampling in the driver worker. - if not sampling_metadata.perform_sampling: + if not self.is_driver_worker: return None # Sample the next token. @@ -693,6 +739,7 @@ def execute_model( logits=logits, sampling_metadata=sampling_metadata, ) + return output @torch.inference_mode() @@ -701,7 +748,6 @@ def profile_run(self) -> None: sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs - # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request @@ -709,20 +755,22 @@ def profile_run(self) -> None: dummy_lora_requests = [] dummy_lora_requests_per_seq = [] if self.lora_config: - for idx in range(self.lora_config.max_loras): - lora_id = idx + 1 - dummy_lora_request = LoRARequest( - lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_local_path="/not/a/real/path", - ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) - dummy_lora_requests.append(dummy_lora_request) - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) - ] + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. @@ -762,12 +810,12 @@ def profile_run(self) -> None: torch.cuda.synchronize() return - def remove_all_loras(self) -> bool: + def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_all_loras() + self.lora_manager.remove_all_loras() - def set_active_loras(self, lora_requests: List[LoRARequest], + def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") @@ -802,10 +850,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: Since it is used for decoding-only, it assumes there's only 1 token per sequence in the batch. """ - # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never - # deleted before the CUDA graphs. - self.pynccl_backend = pynccl_utils.get_nccl_backend() - assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " "unexpected consequences if the model is not static. To " @@ -824,7 +868,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) - context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( @@ -833,33 +877,26 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce - # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using CUDA - # graph, we use either custom all-reduce kernel or PyTorch NCCL. - # We always prioritize using custom all-reduce kernel but fall back - # to PyTorch or pynccl if it is disabled or not supported. - with custom_all_reduce.capture(): + with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): # Create dummy attn_metadata. attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, slot_mapping=slot_mapping[:batch_size], - prompt_lens=None, - prompt_lens_tensor=None, - num_prompt_tokens=0, - num_generation_tokens=batch_size, - max_subquery_len=None, - max_context_len=self.max_context_len_to_capture, - max_prompt_len=None, - subquery_start_loc=None, + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, seq_start_loc=None, - context_lens=context_lens[:batch_size], + context_lens_tensor=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, - kv_cache_dtype=self.kv_cache_dtype, ) if self.lora_config: @@ -876,6 +913,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: kv_caches, attn_metadata, memory_pool=self.graph_memory_pool, + stream=graph_capture_context.stream, ) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -883,17 +921,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: end_time = time.perf_counter() elapsed_time = end_time - start_time # This usually takes < 10 seconds. - logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") - - def __del__(self) -> None: - # Delete the CUDA graphs before deleting the pynccl communicator. - # NOTE(woosuk): This is necessary because otherwise deadlocks can - # happen. - # FIXME(woosuk): This is a bit hacky. Find a more robust solution. - # TODO(youkaichao): when we get enough user feedback that pynccl is - # more stable than cupy, we can remove this, e.g. in v0.4.1. - self.graph_runners.clear() - self.pynccl_backend = None + logger.info("Graph capturing finished in %.0f secs.", elapsed_time) @property def vocab_size(self) -> int: @@ -904,25 +932,43 @@ class CUDAGraphRunner: def __init__(self, model: nn.Module): self.model = model - self.graph = None self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} + self._graph: Optional[torch.cuda.CUDAGraph] = None + + @property + def graph(self): + assert self._graph is not None + return self._graph + def capture( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - memory_pool, + memory_pool: Optional[Tuple[int, int]], + stream: torch.cuda.Stream, **kwargs, ) -> None: - assert self.graph is None + assert self._graph is None # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with _maybe_pynccl(): - self.model( + self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + **kwargs, + ) + torch.cuda.synchronize() + + # Capture the graph. + self._graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): + hidden_states = self.model( input_ids, positions, kv_caches, @@ -931,29 +977,14 @@ def capture( ) torch.cuda.synchronize() - # Capture the graph. - # NOTE(woosuk): Python 3.8 does not support multi-line with statements. - # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement - self.graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 - with _maybe_pynccl(): - hidden_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - **kwargs, - ) - torch.cuda.synchronize() - # Save the input and output buffers. self.input_buffers = { "input_ids": input_ids, "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "context_lens": attn_metadata.context_lens, - "block_tables": attn_metadata.block_tables, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} return @@ -974,10 +1005,10 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, - non_blocking=True) - self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, - non_blocking=True) + self.input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + self.input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. self.graph.replay() @@ -988,16 +1019,6 @@ def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) -@contextlib.contextmanager -def _maybe_pynccl(): - if pynccl_utils.is_initialized( - ) and not custom_all_reduce.is_initialized(): - with with_pynccl_for_all_reduce(): - yield - else: - yield - - def _get_graph_batch_size(batch_size: int) -> int: """Returns the padded batch size given actual batch size. @@ -1029,3 +1050,15 @@ def _prepare_fake_inputs( prompt_tokens = [0] * seq_len fake_image_input = None return SequenceData(prompt_tokens), fake_image_input + + +def _is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + if isinstance(block_tables, dict) and all( + value is None for value in block_tables.values()): + return True + return False diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index fff721a80c204..a336be04e124f 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,16 +1,15 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch +from torch import nn from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata -from vllm.model_executor.neuron_model_loader import get_neuron_model -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import (async_tensor_h2d, is_pin_memory_available, - make_tensor_with_pad, maybe_expand_dim) +from vllm.model_executor.model_loader.neuron import get_neuron_model +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import is_pin_memory_available, make_tensor_with_pad logger = init_logger(__name__) @@ -34,9 +33,11 @@ def __init__( self.device_config = (device_config if device_config is not None else DeviceConfig()) self.device = self.device_config.device - self.model = None self.pin_memory = is_pin_memory_available() + # Lazy initialization. + self.model: nn.Module # initialize after load_model. + def load_model(self) -> None: self.model = get_neuron_model(self.model_config, parallel_config=self.parallel_config, @@ -51,7 +52,7 @@ def _prepare_prompt( input_positions: List[List[int]] = [] input_block_ids: List[int] = [] - prompt_lens: List[int] = [] + seq_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -60,26 +61,26 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() - prompt_len = len(prompt_tokens) - prompt_lens.append(prompt_len) + seq_len = len(prompt_tokens) + seq_lens.append(seq_len) input_tokens.append(prompt_tokens) - input_positions.append(list(range(prompt_len))) + input_positions.append(list(range(seq_len))) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] assert len(block_table) == 1 input_block_ids.append(block_table[0]) - max_prompt_len = max(prompt_lens) - assert max_prompt_len > 0 + max_seq_len = max(seq_lens) + assert max_seq_len > 0 input_tokens = make_tensor_with_pad(input_tokens, - max_prompt_len, + max_seq_len, pad=0, dtype=torch.long, device=self.device) input_positions = make_tensor_with_pad(input_positions, - max_prompt_len, + max_seq_len, pad=0, dtype=torch.long, device=self.device) @@ -87,7 +88,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - return input_tokens, input_positions, input_block_ids, prompt_lens + return input_tokens, input_positions, input_block_ids, seq_lens def _prepare_decode( self, @@ -138,106 +139,9 @@ def _prepare_decode( return input_tokens, input_positions, input_block_ids - def _prepare_sample( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - ) -> SamplingMetadata: - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - selected_token_indices: List[int] = [] - generators: List[torch.Generator] = [] - selected_token_start_idx = 0 - categorized_sample_indices = {t: [] for t in SamplingType} - categorized_sample_indices_start_idx = 0 - categorized_sampled_token_indices_start_idx = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - seq_groups.append((seq_ids, sampling_params)) - - if seq_group_metadata.is_prompt: - assert len(seq_ids) == 1 - assert prompt_lens is not None - prompt_len = prompt_lens[i] - if sampling_params.prompt_logprobs is not None: - # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += prompt_len - 1 - - categorized_sample_indices[ - sampling_params.sampling_type].append([ - categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx - ]) - categorized_sample_indices_start_idx += 1 - categorized_sampled_token_indices_start_idx += 1 - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + prompt_len - 1)) - selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += prompt_len - - if sampling_params.seed is not None: - seq_group_metadata.state.generator = torch.Generator( - device=self.device).manual_seed(sampling_params.seed) - else: - num_seqs = len(seq_ids) - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) - selected_token_start_idx += num_seqs - - categorized_sample_indices[ - sampling_params.sampling_type].extend( - zip( - range( - categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + - num_seqs), - range( - categorized_sampled_token_indices_start_idx, - categorized_sampled_token_indices_start_idx + - num_seqs))) - categorized_sample_indices_start_idx += num_seqs - categorized_sampled_token_indices_start_idx += num_seqs - - if sampling_params.seed is not None: - generators.append(seq_group_metadata.state.generator) - - selected_token_indices = async_tensor_h2d(selected_token_indices, - dtype=torch.long, - target_device=self.device, - pin_memory=self.pin_memory) - - categorized_sample_indices = { - t: maybe_expand_dim( - async_tensor_h2d(seq_ids, - dtype=torch.int, - target_device=self.device, - pin_memory=self.pin_memory), 2, 2) - for t, seq_ids in categorized_sample_indices.items() - } - - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - generators=generators, - ) - return sampling_metadata - def prepare_input_tensors( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. @@ -245,13 +149,20 @@ def prepare_input_tensors( # Prepare input tensors. if is_prompt: (input_tokens, input_positions, input_block_ids, - prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + seq_lens) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] - sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens) + seq_lens = [] + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # query_lens is not needed if chunked prefill is not + # supported. Since neuron worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens, + self.device, + self.pin_memory) return (input_tokens, input_positions, input_block_ids, sampling_metadata) @@ -259,7 +170,7 @@ def prepare_input_tensors( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, input_block_ids, sampling_metadata ) = self.prepare_input_tensors(seq_group_metadata_list) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 6136d50d0c068..d0e6aaed180e6 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,5 +1,5 @@ """A Neuron worker class.""" -from typing import List, Optional +from typing import List, Tuple import torch import torch.distributed @@ -29,6 +29,10 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.cache_config = cache_config + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() self.model_runner = NeuronModelRunner(model_config, parallel_config, scheduler_config, device_config) @@ -40,7 +44,7 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. Swapping is not yet supported, so always return num_cpu_blocks=0. @@ -73,15 +77,18 @@ def initialize_cache(self, num_gpu_blocks: int, def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Optional[SamplerOutput]: + ) -> List[SamplerOutput]: num_seq_groups = len(seq_group_metadata_list) # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list) - return output + + # Neuron worker only supports single-step output. Wrap the output in a + # list to conform to interface. + return [output] def get_cache_block_size_bytes(self) -> int: """Determine the size in bytes of a cache block. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 1d0fe34125096..37d57de6c26d6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,24 +1,23 @@ """A GPU worker class.""" import gc import os -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.model_executor.parallel_utils import pynccl_utils -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) -from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import is_hip +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import ModelRunner from vllm.worker.worker_base import WorkerBase @@ -38,11 +37,13 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, cache_config: CacheConfig, + load_config: LoadConfig, local_rank: int, rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -54,28 +55,39 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.load_config = load_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() self.vision_language_config = vision_language_config if self.vision_language_config: assert not self.lora_config, ( "To be tested: vision language model with LoRA settings.") - self.model_runner = ModelRunner( + ModelRunnerClass = (EmbeddingModelRunner if + self.model_config.embedding_mode else ModelRunner) + self.model_runner = ModelRunnerClass( model_config, parallel_config, scheduler_config, device_config, + cache_config, + load_config=load_config, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, - vision_language_config=vision_language_config) + vision_language_config=vision_language_config, + ) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine = None - self.gpu_cache = None + self.cache_engine: CacheEngine + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[torch.tensor]] = None def init_device(self) -> None: if self.device_config.device.type == "cuda": @@ -108,6 +120,18 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.model_runner.save_sharded_state( + path, + pattern=pattern, + max_size=max_size, + ) + @torch.inference_mode() def determine_num_available_blocks(self) -> Tuple[int, int]: """Profiles the peak memory usage of the model to determine how many @@ -175,7 +199,6 @@ def _init_cache_engine(self): self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache - self.model_runner.set_block_size(self.cache_engine.block_size) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: @@ -186,56 +209,105 @@ def _warm_up_model(self) -> None: def cache_swap( self, - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + blocks_to_swap_in: torch.Tensor, + blocks_to_swap_out: torch.Tensor, + blocks_to_copy: torch.Tensor, ) -> None: # Issue cache operations. - # TODO(woosuk): Profile swapping overhead and optimize if needed. - if blocks_to_swap_in: + if blocks_to_swap_in.numel() > 0: self.cache_engine.swap_in(blocks_to_swap_in) - if blocks_to_swap_out: + if blocks_to_swap_out.numel() > 0: self.cache_engine.swap_out(blocks_to_swap_out) - if blocks_to_copy: + if blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, - blocks_to_swap_in: Optional[Dict[int, int]] = None, - blocks_to_swap_out: Optional[Dict[int, int]] = None, - blocks_to_copy: Optional[Dict[int, List[int]]] = None, - ) -> Optional[SamplerOutput]: - if self.is_driver_worker: - assert seq_group_metadata_list is not None - num_seq_groups = len(seq_group_metadata_list) - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None - data = { - "num_seq_groups": num_seq_groups, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - } - broadcast_tensor_dict(data, src=0) - else: - data = broadcast_tensor_dict(src=0) - num_seq_groups = data["num_seq_groups"] - blocks_to_swap_in = data["blocks_to_swap_in"] - blocks_to_swap_out = data["blocks_to_swap_out"] - blocks_to_copy = data["blocks_to_copy"] + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[Union[SamplerOutput, PoolerOutput]]: + if not self.is_driver_worker: + self._execute_model_non_driver() + return [] + + if execute_model_req is None: + # This signals that there's no more requests to process for now. + # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. + broadcast_tensor_dict({}, src=0) + return [] + + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + num_seq_groups = len(seq_group_metadata_list) + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. + blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, + device="cpu", + dtype=torch.int64).view(-1, 2) + # `blocks_to_copy` is a gpu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within cuda kernels. + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) + data: Dict[str, Any] = { + "num_seq_groups": num_seq_groups, + "blocks_to_swap_in": blocks_to_swap_in, + "blocks_to_swap_out": blocks_to_swap_out, + "blocks_to_copy": blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list, self.gpu_cache) - return output + + # Worker only supports single-step execution. Wrap the output in a list + # to conform to interface. + return [output] + + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop in parallel worker. + + You can stop the loop by executing a driver worker with an empty output. + See `stop_remote_worker_execution_loop` for more details. + """ + while self._execute_model_non_driver(): + pass + + def _execute_model_non_driver(self) -> bool: + """Execute model in parallel worker. + + Returns True iff there are remaining sequences to process. + """ + assert not self.is_driver_worker + data = broadcast_tensor_dict(src=0) + if not data: + return False + + num_seq_groups = data.get("num_seq_groups", 0) + blocks_to_swap_in = data.get("blocks_to_swap_in") + blocks_to_swap_out = data.get("blocks_to_swap_out") + blocks_to_copy = data.get("blocks_to_copy") + self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return False + + self.model_runner.execute_model(None, self.gpu_cache) + return True def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) @@ -269,43 +341,18 @@ def init_worker_distributed_environment( local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" - if not parallel_config.worker_use_torchrun: + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + if parallel_config.distributed_executor_backend != "torchrun": init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) else: init_distributed_environment(parallel_config.world_size, -1, "env://", local_rank) - if pynccl_utils.is_initialized(): - pynccl_world_size = pynccl_utils.get_world_size() - if pynccl_world_size != parallel_config.world_size: - raise RuntimeError( - "pynccl is already initialized but the pynccl world " - "size does not match parallel_config.world_size " - f"({pynccl_world_size} vs. {parallel_config.world_size}).") - elif parallel_config.world_size > 1 and not is_hip(): - # NOTE(woosuk): We don't initialize pynccl process group when world size - # is 1. - # NOTE(mattwong): We do not use pynccl on ROCm. - pynccl_utils.init_process_group( - world_size=parallel_config.world_size, - local_rank=local_rank, - rank=rank, - init_method=distributed_init_method, - ) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - # Initialize a custom fast all-reduce implementation. - if not parallel_config.disable_custom_all_reduce: - init_custom_ar() - - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - if pynccl_utils.is_initialized(): - pynccl_utils.all_reduce(torch.zeros(1).cuda()) - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e3027c406ffeb..258f31de17d87 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,8 +1,15 @@ +import importlib +import os from abc import ABC, abstractmethod -from typing import Dict, List +from typing import Dict, List, Optional, Set, Tuple +from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import (enable_trace_function_call_for_thread, + update_environment_variables) + +logger = init_logger(__name__) class WorkerBase(ABC): @@ -18,14 +25,14 @@ def init_device(self) -> None: raise NotImplementedError @abstractmethod - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. The implementation may run profiling or other heuristics to determine the size of caches. - Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks are blocks that are "active" on the device and can be appended to. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be appended to. @@ -40,16 +47,16 @@ def initialize_cache(self, num_gpu_blocks: int, raise NotImplementedError @abstractmethod - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - """Executes one model step on the given sequences.""" + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" raise NotImplementedError @abstractmethod - def get_cache_block_size_bytes() -> int: + def get_cache_block_size_bytes(self) -> int: """Return the size of a single cache block, in bytes. Used in speculative decoding. """ @@ -64,7 +71,7 @@ def remove_lora(self, lora_id: int) -> bool: raise NotImplementedError @abstractmethod - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise NotImplementedError @@ -79,5 +86,64 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: raise ValueError(f"{type(self)} does not support LoRA") - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise ValueError(f"{type(self)} does not support LoRA") + + +class WorkerWrapperBase: + """ + The whole point of this class is to lazily initialize the worker. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + """ + + def __init__(self, + worker_module_name=None, + worker_class_name=None, + trust_remote_code: bool = False) -> None: + self.worker_module_name = worker_module_name + self.worker_class_name = worker_class_name + self.worker = None + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + @staticmethod + def update_environment_variables(envs: Dict[str, str]) -> None: + key = 'CUDA_VISIBLE_DEVICES' + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) + + def init_worker(self, *args, **kwargs): + """ + Here we inject some common logic before initializing the worker. + Arguments are passed to the worker class constructor. + """ + enable_trace_function_call_for_thread() + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + + mod = importlib.import_module(self.worker_module_name) + worker_class = getattr(mod, self.worker_class_name) + self.worker = worker_class(*args, **kwargs) + + def execute_method(self, method, *args, **kwargs): + try: + target = self if self.worker is None else self.worker + executor = getattr(target, method) + return executor(*args, **kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method}. " + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e