diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md
index c1aebaf5b3bbe..fbf41eb10a392 100644
--- a/.buildkite/nightly-benchmarks/README.md
+++ b/.buildkite/nightly-benchmarks/README.md
@@ -34,17 +34,18 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performan
 Performance benchmark will be triggered when:
 - A PR being merged into vllm.
-- Every commit for those PRs with `perf-benchmarks` label.
+- Every commit for those PRs with `perf-benchmarks` label AND `ready` label.
 
 Nightly benchmark will be triggered when:
-- Every commit for those PRs with `nightly-benchmarks` label.
+- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label.
 
 ## Performance benchmark details
 
-See [descriptions.md](tests/descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
+
+See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
 
 #### Latency test
 
@@ -68,7 +69,7 @@ Here is an example of one test inside `latency-tests.json`:
 In this example:
 - The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`.
-- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-benchmarks-suite.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`
+- The `parameters` attribute controls the command-line arguments to be used for `benchmark_latency.py`. Note that you should use an underscore `_` instead of a dash `-` when specifying the command-line arguments; `run-performance-benchmarks.sh` will convert the underscores to dashes when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command-line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`
 
 Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly.
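For illustration, the underscore-to-dash conversion described in the README hunk above can be sketched in Python. This snippet is not part of the diff or of `run-performance-benchmarks.sh` (which performs the conversion in shell); the helper name `parameters_to_cli_args` is made up here, and the example values are the ones quoted in the README.

```python
def parameters_to_cli_args(parameters: dict) -> str:
    """Map a `parameters` dict from latency-tests.json to CLI flags,
    replacing underscores in the keys with dashes."""
    return " ".join(f"--{key.replace('_', '-')} {value}"
                    for key, value in parameters.items())


example = {
    "model": "meta-llama/Meta-Llama-3-8B",
    "tensor_parallel_size": 1,
    "load_format": "dummy",
    "num_iters_warmup": 5,
    "num_iters": 15,
}
print(parameters_to_cli_args(example))
# --model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15
```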
diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 8490c9f1da221..2b70e2da5d87c 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -21,7 +21,7 @@ steps: containers: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT command: - - bash .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh resources: limits: nvidia.com/gpu: 8 diff --git a/.buildkite/nightly-benchmarks/tests/descriptions.md b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md similarity index 100% rename from .buildkite/nightly-benchmarks/tests/descriptions.md rename to .buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 534ecf17930e9..f90e464288cf1 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -174,8 +174,8 @@ def results_to_json(latency, throughput, serving): # document the result with open(results_folder / "benchmark_results.md", "w") as f: - results = read_markdown( - "../.buildkite/nightly-benchmarks/tests/descriptions.md") + results = read_markdown("../.buildkite/nightly-benchmarks/" + + "performance-benchmarks-descriptions.md") results = results.format( latency_tests_markdown_table=latency_md_table, throughput_tests_markdown_table=throughput_md_table, diff --git a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh similarity index 90% rename from .buildkite/nightly-benchmarks/run-benchmarks-suite.sh rename to .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index f6e41fcfdd0be..a0b9a409b758d 100644 --- a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -37,9 +37,9 @@ check_hf_token() { ensure_sharegpt_downloaded() { local FILE=ShareGPT_V3_unfiltered_cleaned_split.json if [ ! -f "$FILE" ]; then - wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE else - echo "$FILE already exists." + echo "$FILE already exists." fi } @@ -68,11 +68,29 @@ wait_for_server() { done' && return 0 || return 1 } +kill_processes_launched_by_current_bash() { + # Kill all python processes launched from current bash script + current_shell_pid=$$ + processes=$(ps -eo pid,ppid,command | awk -v ppid="$current_shell_pid" -v proc="$1" '$2 == ppid && $3 ~ proc {print $1}') + if [ -n "$processes" ]; then + echo "Killing the following processes matching '$1':" + echo "$processes" + echo "$processes" | xargs kill -9 + else + echo "No processes found matching '$1'." + fi +} + kill_gpu_processes() { - # kill all processes on GPU. 
- ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9 - ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 + ps -aux + lsof -t -i:8000 | xargs -r kill -9 + pkill -f pt_main_thread + # this line doesn't work now + # ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9 + pkill -f python3 + pkill -f /usr/bin/python3 + # wait until GPU memory usage smaller than 1GB while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do @@ -82,11 +100,6 @@ kill_gpu_processes() { # remove vllm config file rm -rf ~/.config/vllm - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" } upload_to_buildkite() { @@ -104,7 +117,7 @@ upload_to_buildkite() { fi # Use the determined command to annotate and upload artifacts - $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < $RESULTS_FOLDER/benchmark_results.md + $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" <$RESULTS_FOLDER/benchmark_results.md $BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*" } @@ -156,7 +169,7 @@ run_latency_tests() { latency_command: $latency, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" + echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" # run the benchmark eval "$latency_command" @@ -166,7 +179,6 @@ run_latency_tests() { done } - run_throughput_tests() { # run throughput tests using `benchmark_throughput.py` # $1: a json file specifying throughput test cases @@ -214,7 +226,7 @@ run_throughput_tests() { throughput_command: $command, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" + echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" # run the benchmark eval "$throughput_command" @@ -246,7 +258,6 @@ run_serving_tests() { continue fi - # get client and server arguments server_params=$(echo "$params" | jq -r '.server_parameters') client_params=$(echo "$params" | jq -r '.client_parameters') @@ -324,7 +335,7 @@ run_serving_tests() { client_command: $client, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/${new_test_name}.commands" + echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" done @@ -341,6 +352,7 @@ main() { # dependencies (which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which jq) || (apt-get update && apt-get -y install jq) + (which lsof) || (apt-get update && apt-get install -y lsof) # get the current IP address, required by benchmark_serving.py export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') @@ -359,7 +371,6 @@ main() { run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json - # postprocess benchmarking results pip install tabulate pandas python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7babffc62f431..59d7241bd452d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -111,10 +111,10 @@ steps: commands: - pytest -v -s metrics - "pip install \ - opentelemetry-sdk \ - opentelemetry-api \ - opentelemetry-exporter-otlp \ - opentelemetry-semantic-conventions-ai" + 
'opentelemetry-sdk>=1.26.0,<1.27.0' \ + 'opentelemetry-api>=1.26.0,<1.27.0' \ + 'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \ + 'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'" - pytest -v -s tracing ##### fast check tests ##### @@ -311,6 +311,15 @@ steps: - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py +- label: Multi-step Tests (4 GPUs) # 10min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/ + - tests/multi_step/test_correctness.py + commands: + - pytest -v -s multi_step/test_correctness.py + - label: Pipeline Parallelism Test # 23min working_dir: "/vllm-workspace/tests" num_gpus: 4 diff --git a/.gitignore b/.gitignore index 2dfbf64d75c1b..761b00ac3bc48 100644 --- a/.gitignore +++ b/.gitignore @@ -87,6 +87,9 @@ target/ profile_default/ ipython_config.py +# generated files +**/generated/** + # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: diff --git a/CMakeLists.txt b/CMakeLists.txt index d47f1bb305a96..c8d4aaeda9091 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -227,6 +227,46 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "-gencode arch=compute_90a,code=sm_90a") endif() + # + # For the Machete kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:$PYTHONPATH + ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py + RESULT_VARIABLE machete_generation_result + OUTPUT_VARIABLE machete_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ) + + if (NOT machete_generation_result EQUAL 0) + message(FATAL_ERROR "Machete generation failed." 
+ " Result: \"${machete_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") + else() + message(STATUS "Machete generation completed successfully.") + endif() + + # Add machete generated sources + file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu") + list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES}) + message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}") + + # See comment above for scaled_mm_c3x (same if condition) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + ${MACHETE_GEN_SOURCES} + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() + + # Add pytorch binding + list(APPEND VLLM_EXT_SRC + csrc/quantization/machete/machete_pytorch.cu) endif() define_gpu_extension_target( diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py new file mode 100644 index 0000000000000..ca45cba6f8165 --- /dev/null +++ b/benchmarks/kernels/benchmark_machete.py @@ -0,0 +1,372 @@ +import argparse +import copy +import itertools +import math +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, pack_rows, quantize_weights) +from vllm.scalar_type import ScalarType, scalar_types +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] +DEFAULT_TP_SIZES = [1] + + +def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor: + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # make col major + return ops.machete_prepack_B(w_q, wtype) + + +def make_bench_tensors( + atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int, + k: int +) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor, + torch.tensor]]]: + assert wtype.is_integer(), "TODO: support floating point weights" + + # we want to make sure that weights don't fit into L2 cache between runs so + # we construct enough weights to exceed L2 cache, which is 50mb on a H100 + # so we target total weight size > 2*50mb + num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits)) + + a = torch.randn((m, k), device="cuda", dtype=atype) * 5 + weights = [ + torch.randn((k, n), device="cuda", dtype=atype) + for _ in range(num_weights) + ] + quanitized_weights = [ + quantize_weights(w, wtype, group_size) for w in weights + ] + + return a, quanitized_weights + + +# impl + + +# bench +def bench_fn(label: str, sub_label: str, description: str, + fn: Callable) -> TMeasurement: + + min_run_time = 1 + return TBenchmark.Timer( + stmt="fn()", + globals={ + "fn": fn + }, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def loop_over_weights( + a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor, + torch.tensor, 
torch.tensor]], + fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor], + None]): + for w_ref, w_q, w_s, _ in weights: + fn(a, w_ref, w_q, w_s) + + +def bench(atype: torch.dtype, + wtype: ScalarType, + group_size: int, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + benchmark_marlinv1: bool = True, + sweep_schedules: bool = True) -> Iterable[TMeasurement]: + a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) + sub_label += f", L={len(weights)}" + + weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp) + for w_ref, w_q, w_s, w_zp in weights] + + timers = [] + # pytorch impl + timers.append( + bench_fn( + label, sub_label, "torch.matmul", lambda: loop_over_weights( + a, + weights, + lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref), + ))) + + if benchmark_marlinv1: + w_ref = weights[0][0] + + w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device) + sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device) + g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device) + + def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor: + w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape) + return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape, + wtype.size_bits) + + def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: + return marlin_permute_scales(w_s, *w_ref.shape, group_size) + + weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q), + marlinv1_permute_scales(w_s), w_zp) + for w_ref, w_q, w_s, w_zp in weights] + + workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + # marlinv1 + timers.append( + bench_fn( + label, sub_label, "marlin_orig", lambda: loop_over_weights( + a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops. 
+ gptq_marlin_gemm(a, + w_q, + w_s, + w_zp_empty, + g_idx, + sort_indices, + workspace.scratch, + wtype, + size_m=a.shape[0], + size_n=w_ref.shape[1], + size_k=w_ref.shape[0], + is_k_full=True)))) + + # machete + timers.append( + bench_fn( + label, sub_label, "machete_heuristic", lambda: loop_over_weights( + a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm( + a, w_q, wtype, b_scales=w_s, b_group_size=group_size)))) + + if sweep_schedules: + print("Finding best schedule for machete") + best = None + best_schedule = None + schedules = ops.machete_supported_schedules(wtype) + for schedule in reversed(schedules): + + def run(a, _, w_q, w_s, schedule=schedule): + ops.machete_gemm(a, + w_q, + wtype, + w_s, + b_group_size=group_size, + schedule=schedule) + + res = bench_fn(label, sub_label, "machete_best", + lambda: loop_over_weights(a, weights_machete, run)) + + print(f" {res.median:5.5} ", schedule) + if not best or res.median < best.median: + best = res + best_schedule = schedule + print("Best schedule:", best_schedule) + timers.append(best) + + return timers + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, sweep_schedules: bool, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + + results = [] + for m, k, n in MKNs: + timers = bench(dtype, + scalar_types.uint4b8, + 128, + m, + k, + n, + f"{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + sweep_schedules=sweep_schedules) + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None, +): + + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, args.sweep_schedules, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, args.sweep_schedules, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, args.sweep_schedules, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp 
in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == "__main__": + + def to_torch_dtype(dt): + if dt == "bfloat16": + return torch.bfloat16 + if dt == "float16": + return torch.float16 + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Machete GEMM. + + To run square GEMMs: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['bfloat16', 'float16']", + ) + parser.add_argument( + "--sweep-schedules", + action="store_true", + help="Run a sweep over all supported schedules", + ) + subparsers = parser.add_subparsers(dest="cmd", required=True) + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py new file mode 100644 index 0000000000000..1d076ed6d5c18 --- /dev/null +++ b/benchmarks/kernels/graph_machete_bench.py @@ -0,0 +1,64 @@ +import math +import pickle +import re +from collections import defaultdict +from typing import List + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from torch.utils.benchmark import Measurement as TMeasurement + +from vllm.utils import 
FlexibleArgumentParser + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Benchmark the latency of processing a single batch of ' + 'requests till completion.') + parser.add_argument('filename', type=str) + + args = parser.parse_args() + + with open(args.filename, 'rb') as f: + data: List[TMeasurement] = pickle.load(f) + + results = defaultdict(lambda: list()) + for v in data: + result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) + if result is not None: + KN = result.group(1) + else: + raise Exception("MKN not found") + result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label) + if result is not None: + M = result.group(1) + else: + raise Exception("MKN not found") + + kernel = v.task_spec.description + results[KN].append({ + "kernel": kernel, + "batch_size": M, + "median": v.median + }) + + rows = int(math.ceil(len(results) / 2)) + fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) + axs = axs.flatten() + axs_idx = 0 + for shape, data in results.items(): + plt.sca(axs[axs_idx]) + df = pd.DataFrame(data) + sns.lineplot(data=df, + x="batch_size", + y="median", + hue="kernel", + style="kernel", + markers=True, + dashes=False, + palette="Dark2") + plt.title(f"Shape: {shape}") + plt.ylabel("time (median, s)") + axs_idx += 1 + plt.tight_layout() + plt.savefig("graph_machete_bench.pdf") diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py new file mode 100644 index 0000000000000..25ec9d6028627 --- /dev/null +++ b/benchmarks/kernels/weight_shapes.py @@ -0,0 +1,43 @@ +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-3-8b": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], +} diff --git a/collect_env.py b/collect_env.py index 76df97b099b1b..839d54172e775 100644 --- a/collect_env.py +++ b/collect_env.py @@ -66,6 +66,8 @@ "nccl", "transformers", "zmq", + "nvidia", + "pynvml", } DEFAULT_PIP_PATTERNS = { @@ -79,6 +81,8 @@ "nccl", "transformers", "zmq", + "nvidia", + "pynvml", } diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 73944f4c14890..c35224218e91c 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,5 +1,15 @@ #pragma once +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) + #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ + #define DEVICE_INLINE __forceinline__ __device__ + #define HOST_INLINE __forceinline__ __host__ +#else + #define HOST_DEVICE_INLINE inline + #define DEVICE_INLINE inline + #define HOST_INLINE inline +#endif + int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); diff --git 
a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh new file mode 100644 index 0000000000000..1842fab8b2cac --- /dev/null +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +namespace cute { + +//////////////////////////////////////////////////////////////////// +// layout utils +//////////////////////////////////////////////////////////////////// + +// Permute layout based on indices, example: +// permute_layout<1, 0>(layout) will swap the two dimensions +// permute_layout<0, 2, 1>(layout) will swap the last two dimensions +template +CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { + static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch"); + return cute::make_layout(cute::get(l)...); +} + +// is the layout f(x) = x +template +CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { + if constexpr (std::is_same_v) + return true; + else { + constexpr auto coalesced_layout = coalesce(Layout{}); + if constexpr (rank(coalesced_layout) == 1 && + stride<0>(coalesced_layout) == 1) { + return true; + } + return false; + } +} + +//////////////////////////////////////////////////////////////////// +// Pointer utils +//////////////////////////////////////////////////////////////////// + +template +static constexpr auto get_logical_ptr(PointerType* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return cute::subbyte_iterator(ptr); + } else { + return ptr; + } +} + +//////////////////////////////////////////////////////////////////// +// Misc utils +//////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { + constexpr auto bits = sizeof_bits_v * Elements{}; + if constexpr (bits % 128 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } else if constexpr (bits % 64 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<64>{}; + } else if constexpr (bits % 32 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<32>{}; + } else if constexpr (bits % 16 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<16>{}; + } else { + return AutoVectorizingCopyWithAssumedAlignment<8>{}; + } +} + +}; // namespace cute diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp new file mode 100644 index 0000000000000..1618a340ce10e --- /dev/null +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -0,0 +1,154 @@ +#pragma once + +#include + +#include "cute/layout.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +using ColumnMajor = typename cutlass::layout::ColumnMajor; +using RowMajor = typename cutlass::layout::RowMajor; + +namespace cute { + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, + seq) { + return g(f(cute::get(static_cast(t)), I)...); +} + +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { + return make_shape(f(I)...); +} + +}; // namespace detail + +template +CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { + if constexpr (cute::is_tuple::value) { + return detail::tapply_with_idx( + t, f, [](auto const&... 
a) { return cute::make_tuple(a...); }, + tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// calls: make_shape(f(0), f(1), ..., f(N-1)) +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { + return detail::make_shape_from_idx(f, make_seq{}); +} + +}; // namespace cute + +// Make a layout from a tensor with `rank(Stride{})`, where the shape is the +// shape of the passed in tensor and the strides are of type `Stride` and +// contain the strides of the passed in tensor, checking that any static strides +// in `Stride{}` match the strides of the passed in tensor. +// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra +// strides are set to be 0 or 1. +template +static inline auto make_cute_layout(torch::Tensor const& tensor, + std::string_view name = "tensor") { + TORCH_CHECK(tensor.dim() <= rank(Stride{})); + auto stride = cute::transform_with_idx( + Stride{}, [&](auto const& stride_ele, auto const& idx) { + using StrideEle = std::decay_t; + + if (idx < tensor.dim()) { + if constexpr (cute::is_static_v) { + TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", + name, ".stride(", idx, ") to be ", StrideEle::value); + return StrideEle{}; + } else { + return tensor.stride(idx); + } + } else { + // Extra strides are assumed to be 0 or 1 + if constexpr (cute::is_static_v) { + static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } + return StrideEle{}; + } + }); + + auto shape = cute::make_shape_from_idx([&](auto const& idx) { + if (idx < tensor.dim()) + return tensor.size(idx); + else + return int64_t(1); + }); + + return make_layout(shape, stride); +} + +template +static inline auto maybe_make_cute_layout( + c10::optional const& tensor, + std::string_view name = "tensor") { + using Layout = decltype(make_cute_layout(*tensor)); + + if (tensor) { + return std::optional{make_cute_layout(*tensor, name)}; + } else { + return std::optional{}; + } +} + +// +// Torch Type to Cutlass Type (equivalent_cutlass_type) +// + +template +struct equivalent_cutlass_type { + using type = T; +}; + +template +using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::half_t; +}; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::bfloat16_t; +}; + +// +// equivalent_scalar_t (basically inverse of equivalent_cutlass_type) +// + +// Return a `c10::CppTypeToScalarType` compatible type, i.e. 
get the C++ from +// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` +template +struct equivalent_scalar_type { + using type = T; +}; + +template +using equivalent_scalar_type_t = typename equivalent_scalar_type::type; + +template <> +struct equivalent_scalar_type { + using type = c10::Half; +}; + +template <> +struct equivalent_scalar_type { + using type = c10::BFloat16; +}; + +// get equivalent c10::ScalarType tag from compile time type +template +static inline constexpr c10::ScalarType equivalent_scalar_type_v = + c10::CppTypeToScalarType>::value; \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh new file mode 100644 index 0000000000000..085ee1290031f --- /dev/null +++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::gemm::collective { +using namespace cute; + +// +// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for +// for custom kernel tags, allowing you to build custom collectives. Without +// touching the cutlass library headers, using `CutlassKernelTag` will mean it +// will resort to using the standard cutlass collective builder. +// + +// Use the default Cutlass collective builder, i.e. use an unmodified cutless +// collective +struct CutlassKernelTag {}; + +template +struct VLLMCollectiveBuilder { + static_assert(sizeof(ElementA) == 0, + "Could not build a collective for given parameters."); +}; + +template +struct VLLMCollectiveBuilder< + CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, + ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType> { + using CollectiveOp = typename CollectiveBuilder< + ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB, + GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_custom_types.cuh b/csrc/cutlass_extensions/vllm_custom_types.cuh new file mode 100644 index 0000000000000..6146bdc1f08c6 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_custom_types.cuh @@ -0,0 +1,50 @@ +#pragma once + +#include "cutlass/integer_subbyte.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct vllm_biased_integer_subbyte : public integer_subbyte { + using Base = integer_subbyte; + + using Storage = typename Base::Storage; + using xint_t = typename Base::xint_t; + + using Base::bits_mask_; + using Base::sign_mask_; + using Base::storage; + + // + // Methods + // + + /// No operation + vllm_biased_integer_subbyte() = default; + + /// Conversion from integer type + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value) + : Base(value) {} +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// "GPTQ" types, i.e. 
symmetric quantization +using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8 +using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128 + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct sizeof_bits> { + static constexpr int value = Bits; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py new file mode 100644 index 0000000000000..4fcfcd311aa91 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -0,0 +1,49 @@ +import enum +from typing import Dict, Union + +from cutlass_library import * + +# +# Extend cutlass library with custom types, and missing values +# + + +class VLLMDataType(enum.Enum): + u4b8 = enum_auto() + u8b128 = enum_auto() + + +class MixedInputKernelScheduleType(enum.Enum): + TmaWarpSpecializedMixedInput = enum_auto() + TmaWarpSpecializedPingpongMixedInput = enum_auto() + TmaWarpSpecializedCooperativeMixedInput = enum_auto() + + +VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { + **DataTypeNames, # type: ignore + **{ + VLLMDataType.u4b8: "u4b8", + VLLMDataType.u8b128: "u8b128", + } +} + +VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + **DataTypeTag, # type: ignore + **{ + VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", + VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", + } +} + +VLLMKernelScheduleTag: Dict[Union[ + MixedInputKernelScheduleType, KernelScheduleType], str] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput", + } + } diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh new file mode 100644 index 0000000000000..2ad914f8e9868 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -0,0 +1,795 @@ +#pragma once + +#include "cutlass/numeric_conversion.h" +#include "cutlass_extensions/vllm_custom_types.cuh" +#include "cutlass_extensions/cute_utils.cuh" + +// this file extends: +// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h +// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t +// as well as adds interleaved numeric array converters for specific types. +// (interleaved numeric array converters can be more efficient for subbyte +// types) + +namespace cutlass { + +// InterleavedNumericArrayConverter is like NumericArrayConverter but also +// deinterleaves converted elements based on IlvBlkLayout, interleaving can +// make subbyte converts more efficient by allowing for efficient extraction +// of subbyte elements from a 32bit register. 
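+// The primary template below is only a fallback: its convert() hits
+// CUTE_INVALID_CONTROL_PATH. Usable conversions come from the partial
+// specializations later in this file; the identity-layout specialization
+// simply forwards to the plain NumericArrayConverter.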
+template +struct InterleavedNumericArrayConverter { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + CUTE_INVALID_CONTROL_PATH( + "InterleavedNumericArrayConverter not implemented\n"); + return {}; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct InterleavedNumericArrayConverter< + IlvBlkLayout, T, S, N, Round, + std::enable_if_t()>> { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return Converter::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// TODO (LucasWilkinson): Implement +// for Array <= Array + +// .... + +template +struct ArrayConverterPacked32Bit { + using result_type = Array; + using source_type = Array; + + using result_packed_8_t = Array; + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_8_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + static_assert(N % 2 == 0, "N must be a multiple of 2"); + static_assert(cutlass::sizeof_bits_v >= 4); // TODO: add 16 packed sources + static_assert(32 % cutlass::sizeof_bits_v == 0); + static constexpr auto src_elems_per_32bit_reg = + 32 / cutlass::sizeof_bits_v; + + // Maybe not Valid. ScalarConverter will not actually work unless + // NumericConverter is implemented. However it won't be used + // anyways since we assert N % 2 == 0, just here for compliance with + // VectorizedConverter. + using ScalarConverter = NumericConverter; + + template + CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { + if constexpr (sizeof(PackedSrc) == 1) { + return static_cast(reinterpret_cast(source)); + } else if constexpr (sizeof(PackedSrc) == 2) { + return static_cast(reinterpret_cast(source)); + } else { + static_assert(sizeof(PackedSrc) == 4); + return reinterpret_cast(source); + } + } + + // The core converter uses bit tricks to construct a known FP16 number, then + // does a subtraction in FP16 for the final result. 
+ template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert(PackedSrcType::kElements == PackedResultType::kElements); + static_assert(PackedResultType::kElements == 2 || + PackedResultType::kElements == 4 || + PackedResultType::kElements == 8, + "Invalid PackedResultType must be 2, 4 or 8."); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + return RegConvert32bit::template convert(to_reg(source)); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + ArrayConverterPacked32Bit; + + if constexpr (src_elems_per_32bit_reg >= 8) { + detail::VectorizedConverter::convert< + ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t, + src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source); + } else if constexpr (src_elems_per_32bit_reg >= 4) { + detail::VectorizedConverter::convert(result, source); + } else { + detail::VectorizedConverter::convert(result, source); + } + + return result; + } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + // Below constructs the following temporary: + // fp16s_01 = {0x00, i4_01, 0x00, i4_01} + // fp16s_23 = {0x00, i4_23, 0x00, i4_23} + // fp16s_45 = {0x00, i4_45, 0x00, i4_45} + // fp16s_67 = {0x00, i4_67, 0x00, i4_67} + // We use inline asm instead of __byte_perm intrinsic since we don't want + // the documented (& 0x7) on the index. NVCC might be able to optimize it + // out since the index is a constexpr, but we choose to be safe about it + // here. + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for F16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src), "n"(0), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a fp16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the FP16 to the correct value for the + // FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)}, + // where x1 in the high nibble and x0 is the low nibble then using hfma + // to subtract 1032 from that + // The AND does the following: + // 1) Clear the set bits for the int4 we will ignore. + // We use lop3 so that we can use 1 instruction for AND and XOR. 
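+      // Worked example for the low nibble: a stored value of 11 (x0 = 3,
+      // biased by 8) becomes 0x640B after the lop3 below, which is the fp16
+      // encoding of 1024 + 11 = 1035; the hfma further down then computes
+      // 1035 * 1 - 1032 = 3, recovering x0 without any int-to-float cvt.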
+ static constexpr uint32_t xor_mask = 0x64006400; + static constexpr uint32_t and_mask = 0xFFF0FF0F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 hfmas that do the following: + // {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032} + // = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032} + static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032} + static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1} + + const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); + const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032} + // For high nibble: + // {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16} + // - {72, 72} + static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return 
ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024} + // For high nibble: + // {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64} + static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + + uint32_t const prmt_indices[2] = {0x5150, 0x5352}; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(r[ii]) + : "r"(src), "n"(start_byte_for_fp16), + "r"(prmt_indices[ii])); + } + + // -128 is folded into bias subtraction, i.e. 
the 0x80 in the low bytes + static constexpr uint32_t bias_rep = 0x64806480; + const half2& bias = reinterpret_cast(bias_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + PackedResultType r; + + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of + // u8x4 source and stores the result in r (without introducing extra + // cvt.u32.u8 instruction) + uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; + uint32_t* result_as_int = reinterpret_cast(&r); + for (int ii = 0; ii < PackedResultType::kElements; ++ii) { + result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]); + // Subtract the magic number 0x4B000000 from tmp in floating-point + // arithmetic to obtain final result + r[ii] -= (8388608.f + 128.f); // fold in -128 bias + } + + return r; + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { + // Hold output BF16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for uint4b8_t -> BF16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a BF16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the BF16 to the correct value for the + // BF16 magic_num. 
We will be constructing {128 + (x1+8), 128 + (x0+8)} + // and subtracting 136 to get {x1, x0} + static constexpr uint32_t xor_mask = 0x43004300; + static constexpr uint32_t and_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 bfmas that do the following: + // high BF16: + // hi_bf16 - 136, lo_bf16 - 136 + + // This is the BF16 {136, 136} represented as an integer. + static constexpr uint32_t bias_rep = 0x43084308; + const __nv_bfloat162& bias = + reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136} + static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + 
static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128} + static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + // Not Valid, not supported, only here to satisfy the interface and to avoid + // a compile error. ScalarConverter will not actually work until + // NumericConverter is + // implemented + using ScalarConverter = + NumericConverter; + + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private " + "convert dispatch."); + + NumericArrayConverter + convert_uint8_to_f32; + Array tmp = + convert_uint8_to_f32(source); + NumericArrayConverter + convert_f32_to_bf16_; + return convert_f32_to_bf16_(tmp); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/ops.h b/csrc/ops.h index 6094599901022..6bf0cff232528 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -83,6 +83,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k); +namespace machete { + +std::vector supported_schedules( + vllm::ScalarTypeTorchPtr const& btype); + +torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, + vllm::ScalarTypeTorchPtr const& btype, + c10::optional const& scales, + c10::optional const& zeros, + c10::optional group_size, + c10::optional const& C, + 
c10::optional<double> alpha, c10::optional<double> beta,
+                   c10::optional<std::string> schedule);
+
+torch::Tensor prepack_B(torch::Tensor const& B,
+                        vllm::ScalarTypeTorchPtr const& btype);
+
+};  // namespace machete
+
 torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
diff --git a/csrc/quantization/machete/Readme.md b/csrc/quantization/machete/Readme.md
new file mode 100644
index 0000000000000..9ddf8da993b0e
--- /dev/null
+++ b/csrc/quantization/machete/Readme.md
@@ -0,0 +1,45 @@
+# Machete (Mixed Precision Cutlass-Based GEMM)
+
+Machete is a spiritual successor to the Marlin kernel but optimized for Hopper architectures and based on Cutlass. Being based on Cutlass, new type pairs and epilogues are easier to add compared to Marlin.
+
+## Overview
+
+Machete effectively performs
+
+```
+scale_type = w_s.dtype
+compute_type = a.dtype
+out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a
+```
+
+Here `w_q` is the quantized weight matrix, `w_s` the quantization scales, and
+`w_z` the quantization zeropoints.
+
+> **_NOTE:_** `w_z` is applied after the scales so we can
+use FMA operations; this means the supplied zeropoints must have the scales
+pre-applied if they are meant to be subtracted before the scales are
+applied.
+
+## API
+
+The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like:
+
+```
+from vllm import _custom_ops as ops
+
+...
+W_q_packed = ops.machete_prepack_B(w_q, wtype)
+output = ops.machete_gemm(
+    a,
+    b_q=W_q_packed,
+    b_type=wtype,
+    b_scales=w_s,
+    b_group_size=group_size
+)
+```
+
+## Code Generation
+
+Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`.
+
+New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate.
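To make the heuristic mechanism described above more concrete, the sketch below (plain Python, not part of this PR; the schedule names and the `M`-only conditions are illustrative, mirroring `default_heuristic` in `generate.py`) shows how a schedule is picked when the caller does not pass `schedule`: the `(condition, schedule)` pairs are tried top to bottom and the final `(None, schedule)` entry is the fallback.

```
from typing import Callable, List, Optional, Tuple

# (condition, schedule-name) pairs, analogous to `default_heuristic` in
# generate.py; the None condition must come last and acts as the fallback.
Heuristic = List[Tuple[Optional[Callable[[int], bool]], str]]

default_heuristic: Heuristic = [
    (lambda M: M > 64, "128x128_1x1x1_TmaMI_TmaCoop_streamK"),
    (lambda M: M > 32, "128x64_1x1x1_TmaMI_TmaCoop_streamK"),
    (lambda M: M > 16, "128x32_1x1x1_TmaMI_TmaCoop_streamK"),
    (None,             "128x16_1x1x1_TmaMI_TmaCoop_streamK"),
]

def pick_schedule(M: int, heuristic: Heuristic) -> str:
    # Conditions are evaluated top to bottom, just like the if/else chain
    # emitted by DISPATCH_TEMPLATE.
    for cond, schedule in heuristic:
        if cond is None or cond(M):
            return schedule
    raise ValueError("heuristic must end with a (None, schedule) fallback")

print(pick_schedule(48, default_heuristic))  # -> 128x64_1x1x1_TmaMI_TmaCoop_streamK
```

In the generated C++ the equivalent chain is emitted by `DISPATCH_TEMPLATE`, keyed on `M = args.A.size(0)`.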
\ No newline at end of file diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py new file mode 100644 index 0000000000000..09a98a5dd1fd6 --- /dev/null +++ b/csrc/quantization/machete/generate.py @@ -0,0 +1,446 @@ +import itertools +import math +import os +import shutil +from collections.abc import Iterable +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import jinja2 +# yapf conflicts with isort for this block +# yapf: disable +from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag, + EpilogueScheduleType, + MixedInputKernelScheduleType, + TileSchedulerTag, + TileSchedulerType, VLLMDataType, + VLLMDataTypeNames, VLLMDataTypeTag, + VLLMKernelScheduleTag) + +# yapf: enable + +# +# Generator templating +# + +DISPATCH_TEMPLATE = """ +#include "../machete_mm_launcher.cuh" + +namespace machete { +using GemmDispatcher_ = GemmDispatcher< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + {{DataTypeTag[type_config.element_b_scale]}}, // Scales + {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints + +{% for s in schedules %}extern torch::Tensor +impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args); +{% endfor %} +template <> +torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) { + [[maybe_unused]] auto M = args.A.size(0); + [[maybe_unused]] auto N = args.B.size(1); + [[maybe_unused]] auto K = args.A.size(1); + + if (!args.schedule) { + {%- for cond, s in heuristic %} + {%if cond is not none%}if ({{cond}}) + {%- else %}else + {%- endif %} + return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %} + } + + {% for s in schedules %} + if (*args.schedule == "{{ gen_sch_name(s) }}") { + return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args); + } + {% endfor %} + TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) 
is not implemented for " + "schedule = ", *args.schedule); +} + +template <> +std::vector GemmDispatcher_::supported_schedules() { + return { + {% for s in schedules -%} + "{{ gen_sch_name(s) }}"{{ ", + " if not loop.last }}{%- endfor %} + }; +} + +}; // namespace machete +""" + +IMPL_TEMPLATE = """ +#include "../machete_mm_launcher.cuh" + +namespace machete { +template +using Kernel = MacheteKernelTemplate< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + {{DataTypeTag[type_config.element_b_scale]}}, // Scales + {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, + Config, with_C, with_scales, with_zeropoints>; + +{% for sch in schedules %} +{% set schedule_name = gen_sch_name(sch) -%} +struct sch_{{schedule_name}} { + using TileShapeNM = Shape<{{ + to_cute_constant(sch.tile_shape_mn)|join(', ')}}>; + using ClusterShape = Shape<{{ + to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>; + // TODO: Reimplement + // using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}}; + using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}}; + using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}}; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + +torch::Tensor +impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) { + bool with_C = args.C.has_value(), with_scales = args.scales.has_value(), + with_zeropoints = args.zeros.has_value(); + + {% for s in specializations %} + if (with_C == {{s.with_C|lower}} + && with_zeropoints == {{s.with_zeropoints|lower}} + && with_scales == {{s.with_scales|lower}}) { + return run_impl>(args); + }{% endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED( + false, "for the sake of compile times and binary size machete_mm(..) 
is " + " not implemented for with_C=", with_C, ", with_scales=", with_scales, + ", with_zeropoints=", with_zeropoints, + " (for {{type_name}}_sch_{{schedule_name}})"); +} +{% endfor %} + +}; // namespace machete +""" + +PREPACK_TEMPLATE = """ +#include "../machete_prepack_launcher.cuh" + +namespace machete { +using PrepackBDispatcher_ = PrepackBDispatcher< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + {{DataTypeTag[type_config.element_b_scale]}}, // Scales + {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints + +using PrepackedLayoutB = PrepackedLayoutBTemplate< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>; + +template <> +torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) { + return prepack_impl(B); +} +}; // namespace machete +""" + +TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput +TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative + + +@dataclass +class ScheduleConfig: + tile_shape_mn: Tuple[int, int] + cluster_shape_mnk: Tuple[int, int, int] + kernel_schedule: MixedInputKernelScheduleType + epilogue_schedule: EpilogueScheduleType + tile_scheduler: TileSchedulerType + + +@dataclass +class TypeConfig: + element_a: DataType + element_b: Union[DataType, VLLMDataType] + element_b_scale: DataType + element_b_zeropoint: DataType + element_d: DataType + accumulator: DataType + + +@dataclass +class Specialization: + with_C: bool + with_zeropoints: bool + with_scales: bool + + +@dataclass +class ImplConfig: + type_config: TypeConfig + schedule_configs: List[ScheduleConfig] + specializations: List[Specialization] + heuristic: List[Tuple[Optional[str], ScheduleConfig]] + + +def generate_schedule_name(schedule_config: ScheduleConfig) -> str: + tile_shape = ( + f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" + ) + cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" + + f"x{schedule_config.cluster_shape_mnk[1]}" + + f"x{schedule_config.cluster_shape_mnk[2]}") + kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\ + .split("::")[-1] + epilogue_schedule = EpilogueScheduleTag[ + schedule_config.epilogue_schedule].split("::")[-1] + tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\ + .split("::")[-1] + + return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + + f"_{epilogue_schedule}_{tile_scheduler}") + + +# mostly unique shorter schedule_name +def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str: + kernel_terse_names_replace = { + "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_", + "TmaWarpSpecializedCooperative_": "TmaCoop_", + "StreamKScheduler": "streamK", + } + + schedule_name = generate_schedule_name(schedule_config) + for orig, terse in kernel_terse_names_replace.items(): + schedule_name = schedule_name.replace(orig, terse) + return schedule_name + + +# unique type_name +def generate_type_signature(kernel_type_config: TypeConfig): + element_a = VLLMDataTypeNames[kernel_type_config.element_a] + element_b = VLLMDataTypeNames[kernel_type_config.element_b] + element_d = 
VLLMDataTypeNames[kernel_type_config.element_d] + accumulator = VLLMDataTypeNames[kernel_type_config.accumulator] + element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale] + element_zeropoint = VLLMDataTypeNames[ + kernel_type_config.element_b_zeropoint] + + return (f"{element_a}{element_b}{element_d}" + f"{accumulator}{element_scale}{element_zeropoint}") + + +# non-unique shorter type_name +def generate_terse_type_signature(kernel_type_config: TypeConfig): + element_a = VLLMDataTypeNames[kernel_type_config.element_a] + element_b = VLLMDataTypeNames[kernel_type_config.element_b] + + return f"{element_a}{element_b}" + + +def is_power_of_two(n): + return (n != 0) and (n & (n - 1) == 0) + + +def to_cute_constant(value: List[int]): + + def _to_cute_constant(value: int): + if is_power_of_two(value): + return f"_{value}" + else: + return f"Int<{value}>" + + if isinstance(value, Iterable): + return [_to_cute_constant(value) for value in value] + else: + return _to_cute_constant(value) + + +template_globals = { + "DataTypeTag": VLLMDataTypeTag, + "KernelScheduleTag": VLLMKernelScheduleTag, + "EpilogueScheduleTag": EpilogueScheduleTag, + "TileSchedulerTag": TileSchedulerTag, + "to_cute_constant": to_cute_constant, + "gen_sch_name": generate_terse_schedule_name, +} + + +def create_template(template_str): + template = jinja2.Template(template_str) + template.globals.update(template_globals) + return template + + +mm_dispatch_template = create_template(DISPATCH_TEMPLATE) +mm_impl_template = create_template(IMPL_TEMPLATE) +prepack_dispatch_template = create_template(PREPACK_TEMPLATE) + + +def create_sources(impl_config: ImplConfig, num_impl_files=2): + sources = [] + + type_name = generate_type_signature(impl_config.type_config) + terse_type_name = generate_terse_type_signature(impl_config.type_config) + + sources.append(( + f"machete_mm_{terse_type_name}", + mm_dispatch_template.render(type_name=type_name, + type_config=impl_config.type_config, + schedules=impl_config.schedule_configs, + heuristic=impl_config.heuristic), + )) + + sources.append(( + f"machete_prepack_{terse_type_name}", + prepack_dispatch_template.render( + type_name=type_name, + type_config=impl_config.type_config, + ), + )) + + num_schedules = len(impl_config.schedule_configs) + schedules_per_file = math.ceil(num_schedules / num_impl_files) + for part, i in enumerate(range(0, num_schedules, schedules_per_file)): + file_schedules = impl_config.schedule_configs[i:i + schedules_per_file] + + sources.append(( + f"machete_mm_{terse_type_name}_impl_part{part}", + mm_impl_template.render( + type_name=type_name, + type_config=impl_config.type_config, + schedules=file_schedules, + specializations=impl_config.specializations, + ), + )) + return sources + + +def generate(): + # See csrc/quantization/machete/Readme.md, the Codegeneration for more info + # about how this works + SCRIPT_DIR = os.path.dirname(__file__) + + schedules = [ + ScheduleConfig( + tile_shape_mn=tile_shape_mn, + cluster_shape_mnk=cluster_shape_mnk, + kernel_schedule=kernel_schedule, + epilogue_schedule=epilogue_schedule, + tile_scheduler=tile_scheduler, + ) for tile_shape_mn, cluster_shape_mnk in ( + ((128, 16), (1, 1, 1)), + ((128, 32), (1, 1, 1)), + ((128, 64), (1, 1, 1)), + ((128, 128), (1, 1, 1)), + ) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, ) + for tile_scheduler in (TileSchedulerType.StreamK, ) + ] + + # For now we use the same heuristic for all types + default_heuristic = [ + ("M > 64", + ScheduleConfig( + tile_shape_mn=(128, 
128), + cluster_shape_mnk=(1, 1, 1), + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + )), + ("M > 32", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(1, 1, 1), + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + )), + ("M > 16", + ScheduleConfig( + tile_shape_mn=(128, 32), + cluster_shape_mnk=(1, 1, 1), + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + )), + (None, + ScheduleConfig(tile_shape_mn=(128, 16), + cluster_shape_mnk=(1, 1, 1), + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK)) + ] + + impl_configs = [] + + GPTQ_kernel_type_configs = list( + (TypeConfig( + element_a=element_a, + element_b=element_b, + element_b_scale=element_a, + element_b_zeropoint=element_a, + element_d=element_a, + accumulator=DataType.f32, + ) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for element_a in (DataType.f16, DataType.bf16))) + + GPTQ_kernel_specializations = [ + Specialization(with_C=False, with_zeropoints=False, with_scales=True) + ] + + impl_configs += [ + ImplConfig(x[0], x[1], x[2], x[3]) + for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules), + itertools.repeat(GPTQ_kernel_specializations), + itertools.repeat(default_heuristic)) + ] + + AWQ_kernel_type_configs = list( + (TypeConfig( + element_a=element_a, + element_b=element_b, + element_b_scale=element_a, + element_b_zeropoint=element_a, + element_d=element_a, + accumulator=DataType.f32, + ) for element_b in (DataType.u4, DataType.u8) + for element_a in (DataType.f16, DataType.bf16))) + + AWQ_kernel_specializations = [ + Specialization(with_C=False, with_zeropoints=True, with_scales=True) + ] + + impl_configs += [ + ImplConfig(x[0], x[1], x[2], x[3]) + for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules), + itertools.repeat(AWQ_kernel_specializations), + itertools.repeat(default_heuristic)) + ] + + output_dir = os.path.join(SCRIPT_DIR, "generated") + + # Delete the "generated" directory if it exists + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + # Create the "generated" directory + os.makedirs(output_dir) + + # Render each group of configurations into separate files + for impl_config in impl_configs: + for filename, code in create_sources(impl_config): + filepath = os.path.join(output_dir, f"{filename}.cu") + with open(filepath, "w") as output_file: + output_file.write(code) + print(f"Rendered template to {filepath}") + + +if __name__ == "__main__": + generate() diff --git a/csrc/quantization/machete/machete_collective_builder.cuh b/csrc/quantization/machete/machete_collective_builder.cuh new file mode 100644 index 0000000000000..a74cf8b2dd455 --- /dev/null +++ b/csrc/quantization/machete/machete_collective_builder.cuh @@ -0,0 +1,33 @@ +#pragma once + +#include "cutlass_extensions/vllm_collective_builder.cuh" +#include "machete_mainloop.cuh" + +namespace cutlass::gemm::collective { +using namespace cute; + +struct MacheteKernelTag {}; + +template +struct VLLMCollectiveBuilder< + MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_, + GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB, + ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, + KernelScheduleType, + cute::enable_if_t<( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v)>> { + using CollectiveOp = machete::MacheteCollectiveMma< + ElementPairA_, GmemLayoutA_, 
AlignmentA, ElementPairB_, GmemLayoutB_, + AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, + StageCountType, KernelScheduleType>; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/csrc/quantization/machete/machete_interleaving_utils.cuh b/csrc/quantization/machete/machete_interleaving_utils.cuh new file mode 100644 index 0000000000000..d397f87f19acb --- /dev/null +++ b/csrc/quantization/machete/machete_interleaving_utils.cuh @@ -0,0 +1,35 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace machete { + +using namespace cute; + +// get an interleaved block layout where each element consecutive element has a +// stride of bit_stride and the block width is blk_bit_width, +// examples: +// size_bits = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1 +// size_bits = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1) +// size_bits = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1) +// size_bits = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1) +template +CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() { + static_assert(blk_bit_width % bit_stride == 0); + static_assert(bit_stride % cute::sizeof_bits_v == 0); + + constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v; + + if constexpr (cute::sizeof_bits_v == bit_stride) { + // identity layout + return Layout>>{}; + } else { + constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v; + constexpr auto num_strides = elems_per_blk / elems_per_stride; + return Layout, Int>, + Stride, Int<1>>>{}; + } +} + +}; // namespace machete diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh new file mode 100644 index 0000000000000..3d574ad99efda --- /dev/null +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -0,0 +1,1473 @@ +// +// Based off of: +// cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Specifically: +// https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Referred to as upstream from in the comments +// +// The main optimization machete implements compared to upstream is to prepack +// the weight matrix to more closely match the shape of the wgmma instructions +// allowing for wider (ideally 128bit) shared memory loads. For subbyte types +// this is done by packing values from multiple wgmma loads (for a single +// thread) into a single 128bit load. This is very similar to layout used in +// Marlin, although specific to the wgmma instructions. +// +// Since the wgmma instructions only support sourcing from registers for the A +// operand, and we want to upconvert/decompress the weight values/elements +// before feeding them into the tensor cores in registers, we need the weight +// matrix to be A. To achieve this we compute the transpose of Y = XW^t as +// Y^t = W^tX^t. 
This is mostly done outside of this file in +// csrc/quantization/machete/machete_mm_kernel.cuh, but this why A is the +// quantized/narrow type and has the prepacked layout despite the API being: +// B_prepacked = machete_prepack_B(B) +// Y = machete_mm(A, B_prepacked) +// +#pragma once + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/trace.h" + +#include "cutlass/detail/collective.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" + +namespace machete { + +using namespace cute; +using namespace cutlass; +using namespace cutlass::gemm; +using namespace cutlass::gemm::collective; +using namespace cutlass::gemm::collective::detail; + +template +struct MacheteCollectiveMma { + using Schedule = KernelScheduleType; + static_assert( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "KernelSchedule must be one of the warp specialized policies"); + + public: + static constexpr bool ALayoutIsPrepacked = true; + + // Prepacked block shape (N is M in the transposed problem) + using PPBlockShape_MK = typename GmemLayoutA::PPBlockShape_NK; + // Prepacked blocks per dim for a single MMA tile + using PPBlocksPerTile_MK = decltype(make_shape( + size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}), + size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{}))); + + using IlvdBlkLayout = typename GmemLayoutA::IlvdBlkLayout; + + static_assert(size<0>(TileShape_MNK{}) % size<0>(PPBlockShape_MK{}) == 0, + "M in PPBlockShape_MK must evenly divide M TileShape_MNK"); + static_assert(size<2>(TileShape_MNK{}) % size<1>(PPBlockShape_MK{}) == 0, + "K in PPBlockShape_MK must evenly divide K TileShape_MNK"); + + using ArchTag = arch::Sm90; + using TileShape = TileShape_MNK; + using ClusterShape = ClusterShape_MNK; + using ElementA = deduce_mixed_width_dtype_t<0, ElementATuple_>; + using StrideA = TagToStrideA_t; + using ElementB = ElementB_; + using StrideB = TagToStrideB_t; + using ElementAccumulator = ElementAccumulator_; + using ElementMma = ElementB; + using ElementATuple = + cute::conditional_t::value, + cute::tuple, ElementATuple_>; + + static constexpr cute::GMMA::Major GmmaMajorA = + gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = + gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + private: + // + // the setup section (until "section setup end") contains a combination of + // modified code from (used as a starting point): + // `cutlass/gemm/collective/builders/sm90_gmma_builder.inl` + // 
`cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp` + // (upstream) + // + // however in-order to simplify the code we combine a lot of the logic from + // `CollectiveMma` and `CollectiveBuilder` into this class, this also makes + // sense given that we have flexibility on layouts here. We also simplify the + // code by only supporting scales and zeros for A (in the transposed problem, + // B from an API perspective), also since we force A to be the narrow type + // (i.e. the type to be upconverted) we can remove all the `SwapAB` logic in + // the upstream also simplifying the code. This section includes new logic + // (compared ustream) for handling the prepacked-A layouts (in the transposed + // problem, B from an API perspective) + // + using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>; + using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>; + + static constexpr bool IsANarrow = cutlass::sizeof_bits::value < + cutlass::sizeof_bits::value; + static_assert(IsANarrow, + "A must be the narrow one since its the one that flows through " + "registers."); + + public: + static constexpr int PipelineStages = + compute_stage_count_or_override_single_affine_transformed_input< + sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale, + ElementZero, TileShape_MNK>(StageCountType{}); + + struct DispatchPolicy { + constexpr static int Stages = PipelineStages; + using ClusterShape = ClusterShape_MNK; + using Schedule = KernelScheduleType; + }; + + using GmemTiledCopyA = + decltype(sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = + decltype(sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + // ((T, V), (BlocksM, BlocksK), pipe) -> offset + using SmemLayoutA = decltype(GmemLayoutA::TVbNbKL_to_offset( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + + using SmemLayoutAtomARowMajor = + decltype(rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemLayoutAtomScale = Layout< + Shape(SmemLayoutAtomARowMajor{})), cute::Int<1>>>; + + using SmemLayoutAtomB = + decltype(rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemCopyAtomA = Copy_Atom; + using SmemCopyAtomB = void; + + // + // Validity checks + // + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, + "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + private: + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + public: + // + // Type Aliases + // + using KernelSchedule = KernelScheduleType; + + // For cases where we can't have a void type, we can use this to allow the + // code to compile when the scale / zero is void. + using NonVoidElementScale = + cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = + cute::conditional_t, float, ElementZero>; + + // These are always MN major + using StrideScale = cute::Stride, int64_t, int64_t>; + // For cases where we can't have a void scale, we can use this to allow the + // code to compile when the scale is void. 
+ using NonVoidStrideScale = + cute::conditional_t, + cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert((cutlass::gemm::detail::is_k_major()), + "The transformed matrix (A) must be K-major."); + + static_assert((sizeof(ElementB) == 2) || + (cutlass::gemm::detail::is_k_major() && + cutlass::gemm::detail::is_k_major()), + "The unscaled element (matrix B) must be 2 bytes OR both " + "inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major " + "if B is scaled]."); + + static_assert(std::is_same_v, + "TiledMma::ValTypeC must be the same as ElementAccumulator."); + + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + + using SmemCopyAtomScale = Copy_Atom; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any + // rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = + cute::conditional_t>>; + using InternalElementB = + cute::conditional_t>>; + + using TransformA = cute::identity; + using TransformB = cute::identity; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = + cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), + shape<1>(SmemLayoutAtomScale{}))); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, + "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutACopy = decltype(tile_to_shape( + SmemLayoutAtomARowMajor{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), + Int{}))); + + // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major + // only (e.g. tf32, fp32, fp8, int8). 
+ static constexpr bool IsLayoutAmnBmn = + cute::is_same_v, + layout::ColumnMajor> && + cute::is_same_v, + layout::RowMajor>; + + static_assert(DispatchPolicy::Stages >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc " + "for this mainloop."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + using GmmaSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + // These two restrictions are related, so we place the assertions together. + // To relax them, we need to handle loading more than 1 row of scales for + // every main loop iteration. We must also handle updating the pipeline + // transaction bytes on the fly. NOTE: Deleting this assertion without + // required changes will cause the code to hang. + static_assert(size<1>(SmemLayoutAtomScale{}) == 1, + "size<1>(SmemLayoutAtomScale) must be 1."); + + private: + static constexpr ConversionMode get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = + KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + + // Same as upstream, should be kept the same when possible + static constexpr auto elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } else if constexpr (ModeHasScales) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in scale smem allocation."); + } + } + + // Same as upstream, should be kept the same when possible + static constexpr auto elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale) { + return 0; + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in scale smem allocation."); + } + } + + // Same as upstream, should be kept the same when possible, not formatte for + // easier comparison + // clang-format off + // These methods use some the public members of the class. For that reason, we define them after the public section. 
+ static constexpr uint32_t + compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return baseline_bytes; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return baseline_bytes + scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + // clang-format on + + // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) + using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset( + make_shape(int32_t(0), int32_t(0), int32_t(0))))); + + using ATensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + shape(GmemLayoutA::TVbNbKL_to_offset( + make_shape(int32_t(0), int32_t(0), int32_t(0)))), + PrepackedStrideA{})); + + using BTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(StrideB{}, int32_t(0)), StrideB{})); + using ScaleTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{})); + + using ZeroTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{})); + + static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) { + return make_tma_copy( + GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), + shape(SmemLayoutA{}(_, _, cute::Int<0>{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_scale( + ScaleTensor tensor_scale = ScaleTensor{}) { + return make_tma_copy(GmemTiledCopyScale{}, tensor_scale, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_zero( + ZeroTensor tensor_zero = ZeroTensor{}) { + return make_tma_copy(GmemTiledCopyScale{}, tensor_zero, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_B(BTensor tensor_b = BTensor{}) { + return make_tma_copy( + GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + 
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + } + + public: + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `RealInternalElementA` -> `ElementA` since we support `SwapAB` logic + // clang-format off + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage + { + static constexpr int scale_elements = elements_per_smem_scale(); + static constexpr int zero_elements = elements_per_smem_zero(); + struct TensorStorage : cute::aligned_struct { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + ElementScale const* ptr_S = nullptr; + NonVoidStrideScale dS{}; + int group_size = 0; + ElementZero const* ptr_Z = nullptr; + uint32_t mma_promotion_interval = 4; + }; + // clang-format on + + // + // section setup end + // + + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to + // define the TMA types + // Device side kernel params + struct Params { + public: + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A()); + using TMA_Scale = decltype(make_tma_copy_scale()); + using TMA_Zero = decltype(make_tma_copy_zero()); + using TMA_B = decltype(make_tma_copy_B()); + + // required by outer loop: i.e. 
+ // cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + int64_t scale_k; + int group_size; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here + // to handle the prepacked layout + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + (void)workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is + // only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + auto make_logical_tensor = [&](auto ptr, auto shape, auto stride) { + return make_tensor(get_logical_ptr(ptr), make_layout(shape, stride)); + }; + + typename Params::TMA_A tma_load_a; + typename Params::TMA_B tma_load_b; + typename Params::TMA_Scale tma_load_scale; + typename Params::TMA_Zero tma_load_zero; + + auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + tma_load_a = make_tma_copy_A( + make_logical_tensor(ptr_A, shape(layout), stride(layout))); + + tma_load_b = make_tma_copy_B( + make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB)); + + if constexpr (ModeHasScales) { + tma_load_scale = make_tma_copy_scale(make_logical_tensor( + args.ptr_S, make_shape(M, args.group_size, L), args.dS)); + } + + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + tma_load_zero = make_tma_copy_zero(make_logical_tensor( + args.ptr_Z, make_shape(M, args.group_size, L), args.dS)); + } + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0}; + } else if constexpr (ModeHasScales) { + auto scale_k = (K + args.group_size - 1) / args.group_size; + + return {tma_load_a, tma_load_b, tma_load_scale, + tma_load_zero, scale_k, args.group_size}; + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `SwapAB ? 
N : M -> M` since we dont support SwapAB + // clang-format off + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + const int scale_mn = M; + const int scale_k = (K + args.group_size - 1) / args.group_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.group_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + 
cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + + } + // clang-format off + + // Modified from upstream, should be kept close to that when possible + // the main difference is special handling for the prepacked A layout + // + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the + // contract Returned tuple must contain at least two elements, with the first + // two elements being: gA_mkl - The tma tensor, A after a local tile so it + // has shape (TILE_V,TILE_B,m,k,l) gB_nkl - The tma tensor, B after a local + // tile so it has shape (TILE_N,TILE_K,n,k,l) The rest of the tensors can be + // specified as needed by this collective. + // NOTE: TILE_B is the prepacked block index within a tile. TILE_V is the + // values within a prepacked block. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params) const { + using X = Underscore; + auto M = get<0>(problem_shape_MNKL), N = get<1>(problem_shape_MNKL), + K = get<2>(problem_shape_MNKL), L = get<3>(problem_shape_MNKL); + + // (TILE_V,TILE_B,m,k,l) + auto make_gA_mkl = [&]() { + // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) + auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout)); + return local_tile(mA_mkl, + make_shape(size<0>(layout), PPBlocksPerTile_MK{}), + make_coord(0, make_coord(_, _))); + }; + + // (TILE_N,TILE_K,n,k,l) + auto make_gB_nkl = [&]() { + Tensor mB_nkl = + mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); + return local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), + Step{}); + }; + + // (TILE_M,TILE_Scale_K,m,scale_k,l) + auto make_gS_mkl = [&]() { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor( + make_shape(M, scale_k, L)); + return local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); + }; + + // (TILE_M,TILE_Scale_K,m,scale_k,l) + auto make_gZ_mkl = [&]() { + auto scale_k = mainloop_params.scale_k; + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor( + make_shape(M, scale_k, L)); + return local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _)); + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScale) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl(), + make_gZ_mkl()); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in load_init."); + } + } + + // Similar to upstream, should be kept close to that when possible + // the main difference is in the layout comments + // clang-format off + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + /// This overload gets triggered when we have scales. + template < + class... 
Ts, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (TILE_V,TILE_B,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (TILE_N,TILE_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. 
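+          // e.g. with group_size == 128 and a threadblock tile K of 64,
+          // ReloadFactor == 2 and scale_load_k advances every other k tile,
+          // so each row of scales is reused for 2 consecutive k tiles.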
+ copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + // clang-format off + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + // clang-format on + + // Modified from upstream, should be kept close to that when possible + // the main differences are handling the prepacked A layout, and separating + // the loading of A from upcoverting A + // + // Perform a collective-scoped matrix multiply-accumulate + // Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, FrgTensorC& accum, + int k_tile_count, int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, + "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutB{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for " + "RF sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + // ((T, (FrgV,(RestM, RestK)), (BlocksM, BlocksK), pipe) -> offset + auto constexpr smem_A = SmemLayoutA{}; + + // convert: + // ((T, (MMA,(MMA_M, MMA_K)), (BlocksM, BlocksK), pipe) -> offset + // to: + // (T, MMA, ((MMA_M, BlocksM), (MMA_K, BlocksK)), pipe) -> offset + // which can be thought of as: + // (T, MMA, (MMA_M, MMA_K), pipe) -> offset + auto constexpr smem_A_mma_ = + make_layout(get<0, 0>(smem_A), get<0, 1, 0>(smem_A), + zip(get<0, 1, 1>(smem_A), get<1>(smem_A)), get<2>(smem_A)); + // flatten to: + // (T, MMA, MMA_M, MMA_K, pipe) -> offset + auto constexpr smem_A_mma = smem_A_mma_(_, _, make_coord(_, _), _); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), + smem_A_mma); // (T, MMA, MMA_M, MMA_K, pipe) + Tensor sB = 
make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = sA(thread_idx, _, _, _, _); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate fragments and descriptors + Tensor tCrA_load = make_tensor( + tCsA(_, _, _, Int<0>{}).shape()); // (MMA,MMA_N,MMA_K) + Tensor tCrA_mma = make_fragment_like(tCrA_load); + + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + static constexpr int A_CPY_VEC = + decltype(max_common_vector(tCsA, tCrA_load)){}; + + static constexpr int COVERSION_WIDTH = + std::min(A_CPY_VEC, int(size<0>(tCrA_mma))); + + auto load_A_to_registers = [&](int read_stage) { + copy(create_auto_vectorizing_copy(), + tCsA(_, _, _, read_stage), tCrA_load(_, _, _)); + }; + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = + partition_extra_mma_info(thread_mma, shared_tensors); + auto copy_partitions_extra_info = retile_extra_mma_info( + tiled_mma, partitioned_extra_info, warp_group_thread_idx); + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + auto convert_A = [&, a_vec = Int{}](int k_block, + int read_stage) { + load_extra_info_to_registers(partitioned_extra_info, + copy_partitions_extra_info, k_block, + read_stage); + transform_A_kblock(tCrA_load, a_vec, tCrA_mma, partitioned_extra_info, + k_block); + }; + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + load_A_to_registers(read_stage); + convert_A(0, read_stage); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + if (k_block < K_BLOCK_MAX - 1) { + convert_A(k_block + 1, smem_pipe_read.index()); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to + // overwrite the A registers for the first mma. 
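The first-k-tile block above (and the mainloop that follows) is a software-pipelining pattern: the A fragment for `k_block + 1` is upconverted while the tensor cores consume `k_block`, so conversion latency hides behind the wgmma instructions. A schematic sketch of that overlap, deliberately ignoring the real pipeline and barrier API:

```cpp
// Schematic sketch of the convert/MMA overlap used above; the functions here
// are placeholders, not the CUTLASS pipeline API.
#include <cstdio>

constexpr int K_BLOCK_MAX = 4;

void convert_A(int k_block) { std::printf("  convert k_block %d\n", k_block); }
void mma(int k_block)       { std::printf("  mma     k_block %d\n", k_block); }

int main() {
  convert_A(0);                // prologue: first block converted up front
  for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
    if (k_block < K_BLOCK_MAX - 1) {
      convert_A(k_block + 1);  // prepare the next block...
    }
    mma(k_block);              // ...while this block is multiplied
  }
  return 0;
}
```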
+ warpgroup_wait(); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + load_A_to_registers(smem_pipe_read.index()); + convert_A(0, smem_pipe_read.index()); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, + // so we can release prior barrier + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ + // on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + load_A_to_registers(smem_pipe_read.index()); + convert_A(0, smem_pipe_read.index()); + } else { + convert_A(k_block + 1, read_stage); + } + } + warpgroup_fence_operand(accum); + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // release prior barrier + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ + // on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 1) { + convert_A(k_block + 1, read_stage); + } + } + } + + warpgroup_fence_operand(accum); + } + + // Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_release, + int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on + // it + ++smem_pipe_release; + } + } + + private: + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for any additional inputs inside of the TMA load + template + CUTLASS_DEVICE + auto partition_extra_tma_inputs( + Params const& mainloop_params, + cute::tuple const& load_inputs, + TensorStorage& shared_tensors, + uint2 const& cluster_local_block_id, + int const m_coord, + int const l_coord) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = 
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + // clang-format off + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. + template + CUTLASS_DEVICE + auto partition_extra_mma_info( + ThreadMma const& mma_thread_slice, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); + return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // clang-format on + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Returns the tiled copy and copy views for the extra inputs. 
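The partitioning helpers above all follow the same `if constexpr` shape: a single function returns tuples of different arity per conversion mode, and downstream code indexes those tuples with positions that are only meaningful for that mode. A minimal C++17 sketch of the pattern (the `Mode` enum and helper below are hypothetical stand-ins):

```cpp
// Minimal sketch of the per-mode tuple-arity pattern used by the helpers above.
#include <tuple>

enum class Mode { DirectConvert, ConvertAndScale, ConvertAndScaleWithZero };

template <Mode M>
auto partition_extra_inputs(int scales, int zeros) {
  if constexpr (M == Mode::DirectConvert) {
    return std::make_tuple();                // nothing extra to load
  } else if constexpr (M == Mode::ConvertAndScale) {
    return std::make_tuple(scales);          // scales only
  } else {
    return std::make_tuple(scales, zeros);   // scales + zero points
  }
}

static_assert(std::tuple_size_v<
                  decltype(partition_extra_inputs<Mode::DirectConvert>(0, 0))> == 0);
static_assert(std::tuple_size_v<
                  decltype(partition_extra_inputs<Mode::ConvertAndScaleWithZero>(0, 0))> == 2);

int main() { return 0; }
```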
+ template + CUTLASS_DEVICE + auto retile_extra_mma_info( + TiledMma const& tiled_mma, + cute::tuple& partitioned_extra_info, + int const warp_group_thread_idx) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // clang-format on + + // Similar to `copy_A_and_extra_info` upstream, should be kept the same when + // possible + // the main differences this only loads the extra info into registers and + // not A (since we now preload more of A in the main pipeline) + // Load scales and zeros into registers if required + template + CUTLASS_DEVICE void load_extra_info_to_registers( + cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, int k_block, + int read_stage) { + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), + tCrS_copy_view(_, _, k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), + tCrZ_copy_view(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } + } + + // Similar to upstream, should be kept the same when possible. + // the main differences are that `convert_tensor` supports interleaved + // layouts and bfloat16 has been optimized. `transform_internal_A` has also + // been inlined for code simplicity. + // Utilities to transform A. 
+ template + CUTLASS_DEVICE void transform_A_kblock( + TCrA_load const& tCrA_load, cute::Int vec_A, + TCrA_mma& tCrA_mma, cute::tuple const& partitioned_extra_info, + int const k_block) { + auto in = tCrA_load(_, _, k_block); + auto out = tCrA_mma(_, _, k_block); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + convert_tensor(in, out, vec_A); + } else if constexpr (ModeHasScales) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto converted_inputs = + make_fragment_like(tCrA_mma)(_, _, k_block); + auto scales = tCrS(_, _, 0); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, vec_A); + // Apply scales and broadcast across inputs, store in converted_inputs + + // We need to cast to nv_bfloat16 for the multiply since + // `cutlass::bfloat16_t` has an overloaded operator* that upconverts to + // float, which nvcc will not optimize to using vectorized fma + // instructions (i.e. hfma.bf16_v2) + if constexpr (std::is_same_v) { + cute::transform( + recast(converted_inputs), recast(scales), + recast(converted_inputs), cute::multiplies{}); + } else { + cute::transform(converted_inputs, scales, converted_inputs, + cute::multiplies{}); + } + + // Apply zeros if required + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + auto tCrZ = cute::get<3>(partitioned_extra_info); + auto converted_zeros = make_fragment_like(tCrZ)(_, _, 0); + + convert_tensor(tCrZ(_, _, 0), converted_zeros); + if constexpr (std::is_same_v) { + cute::transform(recast(converted_inputs), + recast(converted_zeros), + recast(converted_inputs), cute::plus{}); + } else { + cute::transform(converted_inputs, converted_zeros, converted_inputs, + cute::plus{}); + } + } + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } else { + static_assert(cutlass::detail::dependent_false, + "No A data is loaded."); + } + } + + // Modified from upstream, should be kept the same when possible + // the main differences is that this version supports interleaved converts + // Utilities for transforming the A operand prior to issuing tensorcore math. + template > + CUTLASS_DEVICE void convert_tensor( + Tensor const& in, + Tensor& out, + cute::Int width = {}) { + // This is an element-wise conversion where we expect both tensors to have + // the same layout. As a result, we can cast as a cutlass array to use the + // fast numeric converters without worrying about indexing into the layout. + constexpr int N = cosize_v; + + // The inputs must be backed by registers & be statically sized. 
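The `recast<nv_bfloat16>` branch above exists because packed bf16 multiplies keep the math in 16-bit vector instructions, while `cutlass::bfloat16_t`'s `operator*` round-trips through `float`. A small standalone CUDA sketch of the packed idea (an assumed illustration, not code from this diff; requires SM80+):

```cuda
// Illustrative kernel: scaling bf16 data two elements at a time with the
// packed __hmul2 intrinsic, which is the kind of instruction the recast in
// transform_A_kblock lets the compiler emit for the scale multiply.
#include <cuda_bf16.h>

__global__ void scale_bf16x2(__nv_bfloat162* data, __nv_bfloat162 scale, int n2) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) {
    data[i] = __hmul2(data[i], scale);  // one packed bf16 multiply covers 2 values
  }
}
```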
+ static_assert(is_rmem::value, + "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, + "Output tensor for A conversion must come from registers"); + static_assert(is_static_v, + "Tensor layout for the conversion must be static"); + static_assert(cosize_v == size(TensorLayout{}), + "Cosize and size of the layout must be equal."); + static_assert( + N % ConversionVectorWidth == 0, + "Conversion vector width must divide cosize of the tensor layout."); + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + constexpr cutlass::FloatRoundStyle RoundStyle = + cutlass::FloatRoundStyle::round_to_nearest; + + using Converter = cutlass::InterleavedNumericArrayConverter< + IlvdBlkLayout, DstType, SrcType, ConversionVectorWidth, RoundStyle>; + + constexpr int NumIterations = N / ConversionVectorWidth; + + for (int ii = 0; ii < NumIterations; ++ii) { + SrcArray const* src_array_ptr = + reinterpret_cast(raw_pointer_cast(in.data())) + ii; + DstArray* dst_array_ptr = + reinterpret_cast(raw_pointer_cast(out.data())) + ii; + *dst_array_ptr = Converter::convert(*src_array_ptr); + } + } +}; + +} // namespace machete diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh new file mode 100644 index 0000000000000..046e6e5a53652 --- /dev/null +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -0,0 +1,237 @@ +#pragma once + +#include +#include +#include + +// clang-format off +// The cutlass include order matters (annoyingly) +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/vllm_numeric_conversion.cuh" +#include "machete_collective_builder.cuh" +#include "machete_prepacked_layout.cuh" +#include "machete_interleaving_utils.cuh" + +namespace machete { + +using namespace cute; + +// NOTE This kernel computes D = alpha * A * B + beta * C by computing +// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma +// instructions only support sourcing from registers for the left-hand +// operand, we want to upconvert/decompress the quantized operand in +// register. Since the primary use case we want to support is Y = XW^t where +// W is quantized, in this situation or right-hand operand is quantized so +// we compute the transpose to move it to the left-hand side. 
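The NOTE above rests on the usual transpose identity; stating it explicitly also explains why `create_arguments` below passes the problem shape as `{N, M, K, 1}` and why the quantized operand ends up as the register-resident left-hand input of the wgmma:

```latex
D = \alpha A B + \beta C
\quad\Longleftrightarrow\quad
D^{\mathsf{T}} = \alpha\, B^{\mathsf{T}} A^{\mathsf{T}} + \beta\, C^{\mathsf{T}}
```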
+template +struct MacheteKernelTemplate { + using MmaType = ElementA_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementD = ElementD_; + using ElementC = cute::conditional_t; + using ElementZ = ZeroT; + using ElementS = ScaleT; + + using ElementAccumulator = + AccumulatorT; // Element type for internal accumulation + using ElementCompute = AccumulatorT; // For Epilogue + + using BTypeTuple = cute::conditional_t< + with_scales, + cute::conditional_t, + cute::tuple>, + ElementB>; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + using LayoutScale = cutlass::layout::RowMajor; + // not actually used since B has the prepacked layout, but required by cutlass + using _LayoutB = cutlass::layout::ColumnMajor; + + // Interface strides expected by create_arguments (will get transposed) + using StrideA = cutlass::detail::TagToStrideA_t; + using StrideC = cutlass::detail::TagToStrideA_t; + using StrideD = cutlass::detail::TagToStrideA_t; + using StrideS = cutlass::detail::TagToStrideA_t; + using StrideZ = StrideS; + + using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutC_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; + + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using PrepackedLayoutB = + PrepackedLayoutBTemplate; + + static int constexpr TileShapeK = + 128 * 8 / cutlass::sizeof_bits::value; + static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentC = + (with_C) ? 128 / cutlass::sizeof_bits_v : 0; + static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v; + + using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{}, + cute::Int{})); + using ClusterShape = typename ScheduleConfig::ClusterShape; + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; + using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; + using TileScheduler = typename ScheduleConfig::TileScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose, + AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::VLLMCollectiveBuilder< + cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass, + BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // stride_B is unused (since B is prepacked), but still required by cutlass + using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>; + + using Arguments = typename Gemm::Arguments; + using MainloopArguments = typename GemmKernel::MainloopArguments; + using EpilogueArguments = typename GemmKernel::EpilogueArguments; + + template + static 
Arguments create_arguments( + cudaStream_t stream, + ElementA const* A_ptr, // A is an MxK matrix + Layout const& layout_A, + ElementB const* B_ptr, // B is an KxN prepacked matrix + ElementD* D_ptr, // D is an MxN matrix + Layout const& layout_D, + ElementC const* C_ptr, // C is an MxN matrix + std::optional> const& layout_C, + ElementS const* S_ptr, // S is an scale_KxN matrix + std::optional> const& layout_S, + ElementZ const* Z_ptr, // Z is an scale_KxN matrix + std::optional> const& layout_Z, + ElementCompute alpha, ElementCompute beta, + std::optional maybe_group_size) { + static_assert(!with_zeropoints || with_scales); + + int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); + + int const group_size = maybe_group_size.value_or(K); + int const scale_k = (K + group_size - 1) / group_size; + + TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); + TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N); + + if constexpr (with_C) { + TORCH_CHECK(C_ptr && layout_C); + } else { + TORCH_CHECK(!C_ptr, "C not supported"); + } + + if constexpr (with_scales) { + TORCH_CHECK(S_ptr && layout_S); + TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N)); + } else { + TORCH_CHECK(!S_ptr, "Scales not supported"); + } + + if constexpr (with_zeropoints) { + TORCH_CHECK(Z_ptr && layout_Z); + TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N)); + TORCH_CHECK(layout_S && *layout_Z == *layout_S, + "Scales and zeros must have the same layout"); + } else { + TORCH_CHECK(!Z_ptr, "Zeropoints not supported"); + } + + // Transpose A and D + // A doesn't need to be transposed since cutlass expects a NxK matrix + // for B (which is At) + auto stride_At = layout_A.stride(); + auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride(); + auto stride_Ct = stride_Dt; + if (layout_C) { + stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride(); + } + + MainloopArguments mainloop_arguments{}; + EpilogueArguments epilogue_arguments{ + {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt}; + + if constexpr (with_scales && with_zeropoints) { + auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, + S_ptr, stride_S, group_size, Z_ptr}; + } else if constexpr (with_scales) { + auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); + mainloop_arguments = MainloopArguments{ + B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size}; + } else { + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At}; + } + + return Arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K, 1}, + mainloop_arguments, + epilogue_arguments}; + }; + + static size_t get_workspace_size(Arguments const& args) { + return Gemm::get_workspace_size(args); + } + + static bool can_implement(Arguments const& args) { + return Gemm::can_implement(args) == cutlass::Status::kSuccess; + } + + static void run(Arguments const& args, void* workspace, cudaStream_t stream) { + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(args, workspace, stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Machete kernel failed to initialize workspace"); + + status = gemm_op.run(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed"); + } +}; + +}; // namespace machete diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh new file mode 100644 index 
0000000000000..e2604d4bed3e2 --- /dev/null +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -0,0 +1,95 @@ +#pragma once + +#include +#include + +#include "machete_mm_kernel.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { + +struct PyTorchArguments { + torch::Tensor const& A; + torch::Tensor const& B; + c10::optional const& scales; + c10::optional const& zeros; + c10::optional group_size; + c10::optional const& C; + c10::optional alpha; + c10::optional beta; + c10::optional schedule; +}; + +template +torch::Tensor run_impl(PyTorchArguments args) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); + + auto device = args.A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + using EleA = typename MacheteKernel::ElementA; + using EleB = typename MacheteKernel::ElementB; + using EleC = typename MacheteKernel::ElementC; + using EleD = typename MacheteKernel::ElementD; + using EleScale = typename MacheteKernel::ElementS; + using EleZero = typename MacheteKernel::ElementZ; + + using StrideA = typename MacheteKernel::StrideA; + using StrideC = typename MacheteKernel::StrideC; + using StrideD = typename MacheteKernel::StrideD; + using StrideS = typename MacheteKernel::StrideS; + using StrideZ = typename MacheteKernel::StrideZ; + + int M = args.A.size(0); + int N = args.B.size(1); + int K = args.A.size(1); + + // Allocate output + torch::Tensor D = + torch::empty({M, N}, torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + + auto const &A = args.A, &B = args.B; + auto const &C = args.C, &scales = args.scales, &zeros = args.zeros; + + auto layout_A = make_cute_layout(A, "A"); + auto layout_D = make_cute_layout(D, "D"); + auto layout_C = maybe_make_cute_layout(C, "C"); + auto layout_S = maybe_make_cute_layout(scales, "scales"); + auto layout_Z = maybe_make_cute_layout(zeros, "zeros"); + + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.mutable_data_ptr()); + auto C_ptr = static_cast(C ? C->const_data_ptr() : nullptr); + auto S_ptr = + static_cast(scales ? scales->const_data_ptr() : nullptr); + auto Z_ptr = + static_cast(zeros ? 
zeros->const_data_ptr() : nullptr); + + auto arguments = MacheteKernel::create_arguments( + stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, + layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), + args.group_size.value_or(K)); + TORCH_CHECK(MacheteKernel::can_implement(arguments), + "Machete kernel cannot be run with these arguments"); + + size_t workspace_size = MacheteKernel::get_workspace_size(arguments); + torch::Tensor workspace = torch::empty( + workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device)); + + MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream); + + return D; +}; + +template +struct GemmDispatcher { + static torch::Tensor dispatch(PyTorchArguments args); + static std::vector supported_schedules(); +}; + +}; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_kernel.cuh b/csrc/quantization/machete/machete_prepack_kernel.cuh new file mode 100644 index 0000000000000..8e02104587d17 --- /dev/null +++ b/csrc/quantization/machete/machete_prepack_kernel.cuh @@ -0,0 +1,62 @@ +#pragma once + +#include "machete_mm_kernel.cuh" +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { + +template +static __global__ void prepack_B_kernel(BInTensor B_in, + BTiledOutTensor B_tiled_out) { + auto tB_in = local_tile(B_in, TileShapeNKL{}, + make_coord(blockIdx.x, blockIdx.y, blockIdx.z)); + auto tB_out = B_tiled_out(make_coord(_, _), + make_coord(blockIdx.x, blockIdx.y), blockIdx.z); + + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout, Stride<_32, _1>>{}, + Layout>{}); + + auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + + Tensor thr_tile_S = thr_copy.partition_S(tB_in); + Tensor thr_tile_D = thr_copy.partition_D(tB_out); + + // Construct a register-backed Tensor with the same shape as each thread's + // partition + auto fragment = make_tensor(shape(thr_tile_D)); + + // Copy from GMEM to RMEM and from RMEM to GMEM + copy(tiled_copy, thr_tile_S, fragment); + copy(Copy_Atom{}, fragment, thr_tile_D); +} + +template +static void prepack_B(cudaStream_t stream, + typename PrepackedLayoutB::ElementB const* B_in_ptr, + InLayout B_layout, + typename PrepackedLayoutB::ElementB* B_out_ptr) { + using TileShapeNKL = + decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{})); + auto ilvd_NKbNbKL_to_offset = + PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout)); + + TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0); + TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0); + TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0); + + auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{}); + auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{}); + auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{}); + + auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout); + auto B_tiled_out = + make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset); + + prepack_B_kernel + <<>>(B_in, B_tiled_out); +} + +}; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh new file mode 100644 index 0000000000000..686dd68bd52bb --- /dev/null +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -0,0 +1,71 @@ +#pragma once + +#include "machete_prepack_kernel.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { 
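For scale, `prepack_B` above launches one 128-thread block per prepacked tile of B and requires every mode of the layout to be divisible by the tile shape. A worked example of that grid sizing with assumed weight dimensions (PPBlockShape_NK of 128x64 and a batch dimension of 1):

```cpp
// Worked example of the grid sizing performed by prepack_B (assumed sizes).
#include <cstdio>

int main() {
  constexpr int tile_n = 128, tile_k = 64, tile_l = 1;  // PPBlockShape_NK x L
  constexpr int N = 4096, K = 4096, L = 1;              // hypothetical (N, K, L) weights

  static_assert(N % tile_n == 0 && K % tile_k == 0 && L % tile_l == 0,
                "every mode must be divisible by the prepacked tile shape");

  std::printf("grid = (%d, %d, %d), block = 128 threads\n",
              N / tile_n, K / tile_k, L / tile_l);       // (32, 64, 1)
  return 0;
}
```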
+ +template +torch::Tensor prepack_impl(torch::Tensor const B) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(B)); + using ElementB = typename PrepackedLayoutB::ElementB; + using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK; + + auto device = B.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + auto B_ptr = static_cast(B.const_data_ptr()); + // elements per storage item for B + auto eles_per_storage = + (B.dtype().itemsize() * 8) / cute::sizeof_bits_v; + + // torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to + // match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L) + auto Bt_packed = B.t(); + + TORCH_CHECK( + (B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0, + "B.shape[0] (in terms of unpacked elements) must be a multiple of ", + size<1>(PPBlockShape_NK{})); + TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0, + "B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{})); + + using StrideB = cutlass::detail::TagToStrideB_t; + auto const l_Bt_packed = make_cute_layout(Bt_packed, "B"); + + // convert (N,packed_K,L) layout to (N,K,L) layout + // in effect we want to do: blocked_product(layout_Bt_packed, + // make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}), + // Step<_1, _0, _2>{})); + // but blocked_product does not support dynamic strides so we implement the + // equivalent manually, + // new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L) + // new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage) + // when s1 == 1 + TORCH_CHECK(stride<1>(l_Bt_packed) == 1); + // clang-format off + auto const layout_Bt = make_layout( + transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) { + return idx == 1 ? ele * eles_per_storage : ele; + }), + transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) { + return idx != 1 ? ele * eles_per_storage : ele; + })); + // clang-format on + + // Allocate output + torch::Tensor D = torch::empty_like(B); + + prepack_B(stream, B_ptr, layout_Bt, + static_cast(D.mutable_data_ptr())); + + return D; +}; + +template +struct PrepackBDispatcher { + static torch::Tensor dispatch(torch::Tensor B); +}; + +}; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh new file mode 100644 index 0000000000000..78e2cc5eec7d8 --- /dev/null +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -0,0 +1,220 @@ +#pragma once + +#include +#include +#include + +// clang-format off +// The cutlass include order matters (annoyingly) + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" +#include "machete_collective_builder.cuh" +#include "machete_interleaving_utils.cuh" + +namespace machete { + +using namespace cute; + +struct IlvBlkLayoutAuto {}; + +// This defines a prepacked layout for the B matrix, where the matrix is broken +// up into PPBlockShape_NK blocks. 
The data within each block is then compactly +// stored in memory such that when performing a TiledMMA operation with the same +// shape as prepacked block, all the data for a given thread is contiguous in +// memory. This allows us to use wider shared memory loads when loading B from +// shared memory. The values within a thread are also potentially interlaeved +// inorder to allow for more efficient upconverting. +// +// The contract here is that the `TiledMma` determined below matches the one +// ultimately used in the kernel. (this is also why the other element types are +// required along with the kernel schedule) +template +// clang-format on +struct PrepackedLayoutBTemplate { + using MmaType = ElementA_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementD = ElementD_; + using ElementAccumulator = + AccumulatorT; // Element type for internal accumulation + using ElementMma = MmaType; + + // Only use interleaved layouts for subbyte weights, prmt instructions makes + // non-interleaved layouts for 8bit+ weights efficient enough we don't need + // iterleaved layouts + using IlvdBlkLayout = std::conditional_t< + std::is_same_v, + std::conditional_t <= 4, + decltype(get_interleaved_blk_layout< + ElementB, sizeof_bits_v, 32>()), + void>, + IlvBlkLayout_>; + + // TODO (LucasWilkinson): compare the performance for other sizes + // Prepacked block shape, smallest layout atom for loading into registers + // (can contain multiple wgmma instructions worth of data in one block) + // We ideally want this to be configured such that a thread can perform 128bit + // loads, i.e. we amount of data associated with each thread within a + // prepacked block is a multiple of 128bits, when using a cooperative sechdule + // we have 256 threads working a single block at a time, this means each + // thread works on `sizeof_bits_v * (128*64) / 256` bits of data, + // for a 4bit type this would be 128bits + using PPBlockShape_NK = Shape<_128, _64>; + + // Create the shape of the tile anticipated to be used by the GEMM kernel, + // when the kernel executes we will compute `Ct = Bt * At` since the + // quantized weights (B), must be the lhs operand so the flow through + // registers. + // The _128 here doesn't actually impact the shape of the stored tile directly + // but may impact the op selected by rs_op_selector + using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{}, + size<1>(PPBlockShape_NK{}))); + + static constexpr cute::GMMA::Major GmmaMajorB = + gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + // Prepacked block, (athrid, val) -> (N,K) + // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K) + CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() { + return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{})); + } + + // Prepacked block, (N,K) -> (athrid, val) + // i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) + CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() { + return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{}); + } + + // Prepacked block, (athrid, val) -> (storage_offset) + // i.e. 
((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx) + CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() { + // Return iterleaved layout + return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{}); + } + + // Prepacked block, (athrid, val) -> (storage_offset) + // i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx) + CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() { + auto layout_no_interleave = + make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{}); + + if constexpr (std::is_same_v) { + return layout_no_interleave; + } else { + // interleave by transforming FrgV into interleaved blocks where each + // block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is + // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4) + // if FrgV is {A, B, C, D, E, F, G, H} + // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H} + auto frgV = get<1, 0>(layout_no_interleave); + auto ilvdBlk = IlvdBlkLayout{}; + static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4"); + auto ilvd_FrgV = make_layout( + make_shape(shape(ilvdBlk), Int{}), + make_stride(stride(ilvdBlk), size(ilvdBlk))); + + // Return iterleaved layout + return make_layout( + get<0>(layout_no_interleave), + make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave))); + } + } + + // Prepacked block, (M,K) -> (storage_offset) + CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() { + // do (M,K) -> (athrid, val) -> (storage_idx) + return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV()); + } + + // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset( + Shape_NKL shape_mkl) { + constexpr auto block_layout = ppblock_TV_to_offset(); + + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx) + auto result = make_layout( + block_layout, + make_layout(blocks_shape, + compact_col_major(blocks_shape, size(block_layout)))); + + // ((athrid, val), (BlocksN, BlocksK, L)) + // => ((athrid, val), (BlocksN, BlocksK), L) + return group<1, 3>(result(_, repeat(result)>(_))); + } + + // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset( + Shape_NKL shape_mkl) { + constexpr auto block_layout = ppblock_ilvd_NK_to_offset(); + + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx) + auto result = make_layout( + block_layout, + make_layout(blocks_shape, + compact_col_major(blocks_shape, size(block_layout)))); + + // ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN, + // BlocksK), L) + return group<1, 3>(result(_, repeat(result)>(_))); + } + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L) + template + CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) { + auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})), + make_layout(size<1>(PPBlockShape_NK{}))); + + // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L) + auto tiled_A = zipped_divide(make_layout(shape_mkl), tile); + return tiled_A.compose(ppblock_TV_to_NK(), _); + } + + // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L) + template + CUTE_HOST_DEVICE static auto 
NKL_to_TVbNbK(Shape_NKL shape_mkl) { + auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl); + return blocked_product(ppblock_NK_to_TV(), + make_layout(shape<1>(TVbNbK_to_NKL_layout))); + } +}; + +}; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu new file mode 100644 index 0000000000000..ef36a490c3c50 --- /dev/null +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -0,0 +1,79 @@ +#include "machete_mm_launcher.cuh" +#include "machete_prepack_launcher.cuh" +#include "core/scalar_type.hpp" + +namespace machete { + +using namespace vllm; + +// +// Utils (type dispatching) +// + +template +static auto scalar_type_dispatch(ScalarType const& type, Fn fn) { + if (type == vllm::kU4) { + return fn(cutlass::uint4b_t{}); + } else if (type == vllm::kU8) { + return fn(cutlass::uint8_t{}); + } else if (type == vllm::kU4B8) { + return fn(cutlass::vllm_uint4b8_t{}); + } else if (type == vllm::kU8B128) { + return fn(cutlass::vllm_uint8b128_t{}); + } else { + TORCH_CHECK(false, "Unsupported type ", type.str()); + } +} + +#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \ + AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__)) + +// +// Interface +// + +std::vector supported_schedules(ScalarTypeTorchPtr const& btype) { + return scalar_type_dispatch(*btype, [&](auto BType) { + return GemmDispatcher::supported_schedules(); + }); +} + +torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, + ScalarTypeTorchPtr const& btype, + c10::optional const& scales, + c10::optional const& zeros, + c10::optional group_size, + c10::optional const& C, + c10::optional alpha, c10::optional beta, + c10::optional schedule) { + auto args = PyTorchArguments{.A = A, + .B = B, + .scales = scales, + .zeros = zeros, + .group_size = group_size, + .C = C, + .alpha = alpha, + .beta = beta, + .schedule = schedule}; + + return scalar_type_dispatch(*btype, [&](auto BType) { + return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES( + A.scalar_type(), "machete_gemm", [&] { + using ComputeType = equivalent_cutlass_type_t; + return GemmDispatcher::dispatch(args); + }); + }); +} + +torch::Tensor prepack_B(torch::Tensor const& B, + ScalarTypeTorchPtr const& btype) { + return scalar_type_dispatch(*btype, [&](auto BType) { + return PrepackBDispatcher::dispatch(B); + }); +} + +}; // namespace machete diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index e26c2e28f2ecd..6d1f53b75f4e2 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -133,6 +133,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); + // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. + ops.def("machete_supported_schedules", &machete::supported_schedules); + ops.def( + "machete_gemm(Tensor A, Tensor B," + " __torch__.torch.classes._core_C.ScalarType btype," + " Tensor? scales, Tensor? zeros, int? group_size," + " Tensor? C, float? alpha, float? beta, str? 
schedule)" + "-> Tensor"); + ops.impl("machete_gemm", torch::kCUDA, &machete::gemm); + ops.def( + "machete_prepack_B(Tensor B," + " __torch__.torch.classes._core_C.ScalarType btype)" + "-> Tensor"); + ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); + // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def("gptq_marlin_gemm", &gptq_marlin_gemm); ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 9a5964ec65b99..6a8d99635b8f0 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -3,6 +3,7 @@ sphinx-book-theme==1.0.1 sphinx-copybutton==0.5.2 myst-parser==2.0.0 sphinx-argparse==0.4.0 +msgspec # packages to install to build the documentation pydantic diff --git a/docs/source/getting_started/openvino-installation.rst b/docs/source/getting_started/openvino-installation.rst index d8f27c4328a58..b67e0410f7441 100644 --- a/docs/source/getting_started/openvino-installation.rst +++ b/docs/source/getting_started/openvino-installation.rst @@ -70,7 +70,7 @@ vLLM OpenVINO backend uses the following environment variables to control behavi - ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used depending on platform. -- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. +- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. You can also export model with different compression techniques using `optimum-cli` and pass exported folder as `` To enable better TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on the experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``) @@ -91,5 +91,3 @@ Limitations - Only LLM models are currently supported. LLaVa and encoder-decoder models are not currently enabled in vLLM OpenVINO integration. - Tensor and pipeline parallelism are not currently enabled in vLLM integration. - -- Speculative sampling is not tested within vLLM integration. diff --git a/examples/production_monitoring/Otel.md b/examples/production_monitoring/Otel.md index 2c7a7caa1bd7c..96d1f96bfa144 100644 --- a/examples/production_monitoring/Otel.md +++ b/examples/production_monitoring/Otel.md @@ -3,10 +3,10 @@ 1. Install OpenTelemetry packages: ``` pip install \ - opentelemetry-sdk \ - opentelemetry-api \ - opentelemetry-exporter-otlp \ - opentelemetry-semantic-conventions-ai + 'opentelemetry-sdk>=1.26.0,<1.27.0' \ + 'opentelemetry-api>=1.26.0,<1.27.0' \ + 'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \ + 'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0' ``` 1. 
Start Jaeger in a docker container: diff --git a/requirements-build.txt b/requirements-build.txt index bea55d930ab25..3f08f5d67b6da 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -5,3 +5,4 @@ packaging setuptools>=49.4.0 torch==2.4.0 wheel +jinja2 diff --git a/requirements-test.txt b/requirements-test.txt index 95909d37e2c94..cdbc3e50cc9ec 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -29,4 +29,5 @@ matplotlib # required for qwen-vl test aiohttp # quantization -bitsandbytes==0.42.0 \ No newline at end of file +bitsandbytes==0.42.0 +buildkite-test-collector==0.1.8 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 554b7f4d3bbfb..2406b8c67361b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,9 @@ from vllm.config import TokenizerPoolConfig from vllm.connections import global_http_connection from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel) + destroy_model_parallel, + init_distributed_environment, + initialize_model_parallel) from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger @@ -92,6 +94,21 @@ def init_test_http_connection(): global_http_connection.reuse_client = False +@pytest.fixture +def dist_init(): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(1, 1) + yield + cleanup() + + def cleanup(): destroy_model_parallel() destroy_distributed_environment() diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 5fb8ec06cfa03..c2226870c2e83 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -682,6 +682,32 @@ def test_eviction_order(num_blocks: int, block_size: int, seed: int): assert new_block[0].block_id == last_block_id + # Test case for cache mertics + @staticmethod + def test_metric(): + block_size = 16 + allocator = PrefixCachingBlockAllocator(num_blocks=4, + block_size=block_size) + # Test when no query (0/0) + assert allocator.get_prefix_cache_hit_rate() == 0.0 + + token_ids = list(range(block_size)) + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + # Test 0/1 hit rate + assert allocator.get_prefix_cache_hit_rate() == 0.0 + + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + # Test 1/2 hit rate + assert allocator.get_prefix_cache_hit_rate() == 0.5 + + # Test more than one block + for _ in range(2, 1005): + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + assert allocator.get_prefix_cache_hit_rate() > 0.99 + @staticmethod def create_immutable_chain( block_size: int, diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py new file mode 100644 index 0000000000000..dadf594409535 --- /dev/null +++ b/tests/kernels/test_machete_gemm.py @@ -0,0 +1,272 @@ +"""Tests for the machete kernel. + +Run `pytest tests/kernels/test_machete_gemm.py`. 
+""" + +import math +from typing import Optional, Tuple + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +MNK_SHAPES = [ + (1, 128, 128), + (1, 512, 1024), + (1, 4096, 4096), + (13, 8192, 4096), + (26, 4096, 8192), + (1, 4096, 4096), + (257, 128, 4096), + (257, 4224, 4160), + (257, 4096, 4096), + (64, 4096, 4096), +] + +ACT_TYPES = [torch.float16, torch.bfloat16] +WTYPE_ZEROPOINTS = [ + # GPTQ style + (scalar_types.uint4b8, False), + (scalar_types.uint8b128, False), + # AWQ style + (scalar_types.uint4, True), + (scalar_types.uint8, True), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + + +def rand_data(shape, dtype=torch.float16): + return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3) + + +def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): + return zps if zps is None else -1 * s * (zps.to(s.dtype)) + + +def machete_quantize_and_pack(w: torch.Tensor, + wtype: ScalarType, + group_size: int, + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) + + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + w_q_machete = ops.machete_prepack_B(w_q, wtype) + + return w_ref, w_q_machete, w_s, w_zp + + +def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor, + wtype: ScalarType, group_size: int, + zero_points: bool): + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + b, wtype, group_size, zero_points) + + output_ref = torch.matmul(a, w_ref) + + output = ops.machete_gemm( + a=a, + b_q=w_q_packed, + b_type=wtype, + b_scales=w_s, + b_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + ) + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) +@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) +@pytest.mark.parametrize("group_size", [128, None]) +def test_machete_all_schedules(shape, atype: torch.dtype, + wtype_zeropoints: Tuple[ScalarType, bool], + group_size: Optional[int]): + m, n, k = shape + wtype, zero_points = wtype_zeropoints + + if group_size is not None and k % group_size != 0: + return + + 
print(f"MNK = {m} {n} {k}") + + # Normalize group_size + if group_size is None: + group_size = k + assert group_size <= k + + a = rand_data((m, k), atype) + w = rand_data((k, n), atype) + + w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack( + w, wtype, group_size, zero_points) + + output_ref = torch.matmul(a, w_ref) + + for schedule in ops.machete_supported_schedules(wtype): + output = ops.machete_gemm( + a, + b_q=w_q_machete, + b_type=wtype, + b_scales=w_s, + b_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + schedule=schedule, + ) + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\ + f"Schedule failed {schedule}" + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) +@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) +@pytest.mark.parametrize("group_size", [128, None]) +def test_machete_heuristic(shape, atype: torch.dtype, + wtype_zeropoints: Tuple[ScalarType, bool], + group_size: Optional[int]): + m, n, k = shape + wtype, zero_points = wtype_zeropoints + + if group_size is not None and k % group_size != 0: + return + + # Normalize group_size + if group_size is None: + group_size = k + assert group_size <= k + + a = rand_data((m, k), atype) + b = rand_data((k, n), atype) + + machete_gemm_test_helper(a, b, wtype, group_size, zero_points) + + +# Test working on other devices +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_machete_devices(device: str): + m, n, k = 512, 4096, 4096 + wtype = scalar_types.uint4b8 + group_size = 128 + zero_points = False + + print(f"MNK = {m} {n} {k}, device = {device}") + + a = rand_data((m, k), torch.float16).to(device) + b = rand_data((k, n), torch.float16).to(device) + + machete_gemm_test_helper(a, b, wtype, group_size, zero_points) + + +# Test working with a subset of A and B +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_subset(): + big_m, big_n, big_k = 1024, 1024, 1024 + m, n, k = 512, 512, 512 + wtype = scalar_types.uint4b8 + group_size = 128 + zero_points = False + + whole_a = rand_data((big_m, big_k), torch.float16) + whole_b = rand_data((big_k, big_n), torch.float16) + + a = whole_a[0:m, 0:k] + b = whole_b[0:k, 0:n] + + machete_gemm_test_helper(a, b, wtype, group_size, zero_points) + + +# Test to make sure cuda graphs work +class MacheteLayer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.machete_gemm(**self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = rand_data((m, k), torch.float16) + b = rand_data((k, n), torch.float16) + wtype = scalar_types.uint4b8 + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + b, wtype, group_size, zero_points) + + # Construct a trivial model with 
a single layer that calls a machete kernel + model = MacheteLayer( + a=a, + b_q=w_q_packed, + b_type=wtype, + b_scales=w_s, + b_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + ) + + output_ref = torch.matmul(a, w_ref) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + output.zero_() + g.replay() + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/tests/models/test_gguf.py b/tests/models/test_gguf.py index 5971179f01211..196cd88e039a1 100644 --- a/tests/models/test_gguf.py +++ b/tests/models/test_gguf.py @@ -7,6 +7,7 @@ import pytest from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer from tests.quantization.utils import is_quant_method_supported @@ -20,7 +21,7 @@ MODELS = [ ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", - filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")), + filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")), ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF", filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")), @@ -39,22 +40,36 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tp_size", [1, 2]) def test_models( + num_gpus_available, vllm_runner, example_prompts, model, dtype: str, max_tokens: int, num_logprobs: int, + tp_size: int, ) -> None: + if num_gpus_available < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + original_model, gguf_model = model + tokenizer = AutoTokenizer.from_pretrained(original_model) + messages = [[{ + 'role': 'user', + 'content': prompt + }] for prompt in example_prompts] + example_prompts = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + # Run unquantized model. 
with vllm_runner(model_name=original_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, - enforce_eager=True, - tensor_parallel_size=1) as original_model: + tensor_parallel_size=tp_size) as original_model: original_outputs = original_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) @@ -63,8 +78,7 @@ def test_models( with vllm_runner(model_name=gguf_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, - enforce_eager=True, - tensor_parallel_size=1) as gguf_model: + tensor_parallel_size=tp_size) as gguf_model: gguf_outputs = gguf_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) diff --git a/tests/models/test_intern_vit.py b/tests/models/test_intern_vit.py new file mode 100644 index 0000000000000..e980446ff3570 --- /dev/null +++ b/tests/models/test_intern_vit.py @@ -0,0 +1,80 @@ +from typing import Optional + +import pytest +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from transformers import AutoConfig, AutoModel, CLIPImageProcessor + +from vllm.model_executor.models.intern_vit import InternVisionModel + +from ..conftest import _ImageAssets, cleanup + +pytestmark = pytest.mark.vlm + +# we use snapshot_download to prevent conflicts between +# dynamic_module and trust_remote_code for hf_runner +DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] +models = [ + snapshot_download("OpenGVLab/InternViT-300M-448px", + allow_patterns=DOWNLOAD_PATTERN), + snapshot_download("OpenGVLab/InternViT-6B-448px-V1-5", + allow_patterns=DOWNLOAD_PATTERN), +] + + +def run_intern_vit_test( + image_assets: _ImageAssets, + model: str, + *, + dtype: str, + distributed_executor_backend: Optional[str] = None, +): + img_processor = CLIPImageProcessor.from_pretrained(model) + images = [asset.pil_image for asset in image_assets] + pixel_values = [ + img_processor(images, return_tensors='pt').pixel_values.to(dtype) + for images in images + ] + + config = AutoConfig.from_pretrained(model, trust_remote_code=True) + if not getattr(config, "norm_type", None): + config.norm_type = "rms_norm" + + hf_model = AutoModel.from_pretrained(model, + torch_dtype=dtype, + trust_remote_code=True).to("cuda") + hf_outputs_per_image = [ + hf_model(pixel_value.to("cuda")).last_hidden_state + for pixel_value in pixel_values + ] + + vllm_model = InternVisionModel(config) + vllm_model.load_weights(hf_model.state_dict().items()) + + del hf_model + cleanup() + + vllm_model = vllm_model.to("cuda", dtype) + vllm_outputs_per_image = [ + vllm_model(pixel_values=pixel_value.to("cuda")) + for pixel_value in pixel_values + ] + del vllm_model + cleanup() + + cos_similar = nn.CosineSimilarity(dim=-1) + for vllm_output, hf_output in zip(vllm_outputs_per_image, + hf_outputs_per_image): + assert cos_similar(vllm_output, hf_output).mean() > 0.99 + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", [torch.half]) +@torch.inference_mode() +def test_models(dist_init, image_assets, model, dtype: str) -> None: + run_intern_vit_test( + image_assets, + model, + dtype=dtype, + ) diff --git a/tests/multi_step/__init__.py b/tests/multi_step/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/multi_step/test_correctness.py b/tests/multi_step/test_correctness.py new file mode 100644 index 0000000000000..bc14311c66424 --- /dev/null +++ b/tests/multi_step/test_correctness.py @@ -0,0 +1,85 @@ +# Test the AsyncLLMEngine with multi-step-decoding + +from typing import List + +import pytest + +from 
..utils import RemoteOpenAIServer + +MODELS = [ + "JackFram/llama-160m", +] +NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps +NUM_PROMPTS = [10] + +DEFAULT_SERVER_ARGS: List[str] = [ + "--disable-log-requests", + "--use-v2-block-manager", + "--worker-use-ray", + "--gpu-memory-utilization", + "0.85", + "--swap-space", + "16", +] + + +async def completions_with_server_args(prompts: List[str], model_name: str, + server_cli_args: List[str]): + + outputs = None + with RemoteOpenAIServer(model_name, server_cli_args) as server: + client = server.get_async_client() + outputs = await client.completions.create(model=model_name, + prompt=prompts, + temperature=0, + stream=False, + max_tokens=5) + assert outputs is not None + + return outputs + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize(("tp_size, pp_size"), [ + (1, 1), + (2, 2), +]) +@pytest.mark.parametrize("eager_mode", [False, True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.asyncio +async def test_multi_step(example_prompts, model: str, tp_size: int, + pp_size: int, eager_mode: int, + num_scheduler_steps: int, num_prompts: int): + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"] + ms_server_args = DEFAULT_SERVER_ARGS + \ + ["--num-scheduler-steps", f"{num_scheduler_steps}"] + + if eager_mode: + ms_server_args.append("--enforce-eager") + + distributed_args = [ + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + ] + + ref_completions = await completions_with_server_args( + prompts, model, server_args + distributed_args) + test_completions = await completions_with_server_args( + prompts, model, ms_server_args + distributed_args) + + def get_text_generations(completions): + return [x.text for x in completions.choices] + + ref_generations = get_text_generations(ref_completions) + test_generations = get_text_generations(test_completions) + assert ref_generations == test_generations diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 9821dbd066a59..2dff84b812b89 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -34,6 +34,9 @@ def test_block_allocator( assert (first_block == second_block) assert (second_block.ref_count == 2) + # Check metric: 1 hit of 2 queries + assert block_allocator.get_prefix_cache_hit_rate() == 0.5 + # Free the first_block and confirm that the ref_count is correctly # decremented on the second block block_allocator.free(first_block) @@ -48,6 +51,10 @@ def test_block_allocator( assert (first_block == second_block) assert (first_block.block_hash == block_hash) + # Allocate one more time to get 3/4 hit rate for easy checking + block_allocator.allocate(block_hash, 0) + assert block_allocator.get_prefix_cache_hit_rate() == 0.75 + @pytest.mark.parametrize("num_blocks", [16]) def test_eviction(num_blocks: int, ): diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index d0f91a63b2d6a..a701f482b4ffb 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, ensure_all_accepted=ensure_all_accepted) -def 
run_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len, - force_output_len: bool, - temperature: float, - seeded: bool, - print_tokens: bool = False, - ensure_all_accepted: bool = False): +def run_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + temperature: float, + seeded: bool, + print_tokens: bool = False, + ensure_all_accepted: bool = False, + expected_acceptance_rate: Optional[float] = None): """Helper method that compares the outputs of both the baseline LLM and the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the same when temperature is zero (or when temperature is > 0 and seeded). @@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator, print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids + print(f'{acceptance_rate=}') + if ensure_all_accepted: assert acceptance_rate == 1.0 + + if expected_acceptance_rate is not None: + assert acceptance_rate >= expected_acceptance_rate - 1e-2 diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 25067e7a4262c..c72e4595fd335 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, force_output_len=True) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + }, +]) +@pytest.mark.parametrize("output_len", [2048]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify acceptance rate with different batch size and large output + length.""" + run_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + temperature=0.0, + seeded=True, + force_output_len=True, + expected_acceptance_rate=0.48) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 2126fafb2323b..1e7f560fc68cc 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -5,11 +5,13 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import CommonAttentionState from vllm.model_executor import SamplingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.worker.embedding_model_runner import ( ModelInputForGPUWithPoolingMetadata) from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +from vllm.worker.multi_step_model_runner import StatefulModelInput class MockAttentionBackend(AttentionBackend): @@ -28,7 +30,11 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: @staticmethod def get_builder_cls() -> 
Type["AttentionMetadataBuilder"]: - raise AttentionMetadataBuilder + return AttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState @staticmethod def get_kv_cache_shape( @@ -154,3 +160,79 @@ def test_embedding_model_runner_input(): None) == getattr(attn_metadata, field.name, None) # Pooling metadata is not broadcast. assert received_model_input.pooling_metadata is None + + +def test_multi_step_model_runner_input(): + sampling_metadata = SamplingMetadata( + ["seq_group"], + "selected_token_indices", + "categorized_sample_indices", + "num_prompts", + ) + attn_metadata = AttentionMetadata( + num_prefills=1, + num_prefill_tokens=2, + num_decode_tokens=3, + slot_mapping=torch.zeros(1), + ) + frozen_model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens=torch.ones(10), + input_positions=torch.ones(10), + sampling_metadata=sampling_metadata, + attn_metadata=attn_metadata) + + model_input = StatefulModelInput( + frozen_model_input=frozen_model_input, + is_last_step=True, + is_first_multi_step=False, + current_step=4, + last_sampled_token_ids=torch.ones((10, 1)), + is_multi_step=True, + num_queries=8, + num_seqs=5, + cached_outputs=[], + ) + + assert isinstance(model_input, StatefulModelInput) + + # Test round trip serialization. + tensor_dict = model_input.as_broadcastable_tensor_dict() + attn_backend = MockAttentionBackend() + received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend)) + + receieved_frozen_input = received_model_input.frozen_model_input + + # Check that received copy has correct values. + assert isinstance(received_model_input, StatefulModelInput) + assert receieved_frozen_input.input_tokens is not None + assert (receieved_frozen_input.input_tokens == + frozen_model_input.input_tokens).all() + assert receieved_frozen_input.input_positions is not None + assert (receieved_frozen_input.input_positions == + frozen_model_input.input_positions).all() + assert receieved_frozen_input.multi_modal_kwargs is None + assert (frozen_model_input.multi_modal_kwargs == + frozen_model_input.multi_modal_kwargs) + assert receieved_frozen_input.lora_requests is None + assert (receieved_frozen_input.lora_requests == + frozen_model_input.lora_requests) + assert receieved_frozen_input.lora_mapping is None + assert ( + receieved_frozen_input.lora_mapping == frozen_model_input.lora_mapping) + for field in dataclasses.fields(AttentionMetadata): + assert getattr(receieved_frozen_input.attn_metadata, field.name, + None) == getattr(attn_metadata, field.name, None) + # For sampling metadata, only selected_token_indices is copied. 
+ assert (receieved_frozen_input.sampling_metadata.selected_token_indices == + sampling_metadata.selected_token_indices) + assert receieved_frozen_input.sampling_metadata.seq_groups is None + + # check non frozen fields + assert received_model_input.is_last_step == model_input.is_last_step + assert (received_model_input.is_first_multi_step == + model_input.is_first_multi_step) + assert received_model_input.current_step == model_input.current_step + assert (received_model_input.last_sampled_token_ids == + model_input.last_sampled_token_ids).all() + assert received_model_input.is_multi_step == model_input.is_multi_step diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1f0a111a53bcc..b89a90ef0f70c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -329,6 +329,32 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, num_bits, size_m, size_n, size_k) +# machete +def machete_supported_schedules(b_type: ScalarType) -> List[str]: + return torch.ops._C.machete_supported_schedules(b_type) + + +def machete_gemm( + a: torch.Tensor, + b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B + b_type: ScalarType, + b_scales: Optional[torch.Tensor] = None, + b_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + c: Optional[torch.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + schedule: Optional[str] = None, +) -> torch.Tensor: + return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros, + b_group_size, c, alpha, beta, schedule) + + +def machete_prepack_B(b_q_weight: torch.Tensor, + b_type: ScalarType) -> torch.Tensor: + return torch.ops._C.machete_prepack_B(b_q_weight, b_type) + + # fp8 def scaled_fp8_quant( input: torch.Tensor, diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 4643d316d48b7..2cd4ad3e00135 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,7 +1,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionMetadataBuilder, - AttentionType) + AttentionState, AttentionType) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -12,5 +12,6 @@ "AttentionType", "AttentionMetadataBuilder", "Attention", + "AttentionState", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 23c7830cd6264..ccfc6b254c1e7 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from contextlib import contextmanager from dataclasses import dataclass, fields from enum import Enum, auto from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, @@ -7,7 +8,9 @@ import torch if TYPE_CHECKING: - from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase + from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerInputBuilderBase) class AttentionType(Enum): @@ -34,6 +37,11 @@ def get_impl_cls() -> Type["AttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: raise NotImplementedError + @staticmethod + @abstractmethod + def get_state_cls() -> Type["AttentionState"]: + raise NotImplementedError + @classmethod def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) @@ -126,6 +134,47 @@ def asdict_zerocopy(self, T = TypeVar("T", bound=AttentionMetadata) +class AttentionState(ABC, Generic[T]): + 
"""Holds attention backend-specific objects reused during the + lifetime of the model runner.""" + + @abstractmethod + def __init__(self, runner: "ModelRunnerBase"): + ... + + @abstractmethod + @contextmanager + def graph_capture(self, max_batch_size: int): + """Context manager used when capturing CUDA graphs.""" + yield + + @abstractmethod + def graph_clone(self, batch_size: int) -> "AttentionState[T]": + """Clone attention state to save in CUDA graph metadata.""" + ... + + @abstractmethod + def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T: + """Get attention metadata for CUDA graph capture of batch_size.""" + ... + + @abstractmethod + def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]: + """Get attention-specific input buffers for CUDA graph capture.""" + ... + + @abstractmethod + def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any], + attn_metadata: T) -> None: + """In-place modify input buffers dict for CUDA graph replay.""" + ... + + @abstractmethod + def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: + """Prepare state for forward pass.""" + ... + + class AttentionMetadataBuilder(ABC, Generic[T]): """Abstract class for attention metadata builders.""" diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 907b45393eeb5..d84a40890ebbd 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -5,7 +5,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import CommonMetadataBuilder +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn, get_head_sliding_step) from vllm.attention.ops.paged_attn import PagedAttention @@ -98,6 +99,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: return BlocksparseFlashAttentionMetadataBuilder + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f146285bfc9e2..30ce715d5d05a 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -9,7 +9,8 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, +from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, + compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -142,6 +143,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: return FlashAttentionMetadataBuilder + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3022fa70e2ca7..2aa3bd79e4a64 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,14 +1,19 @@ +from contextlib import contextmanager from dataclasses import dataclass from 
typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper import vllm.attention.backends.flash_attn # noqa + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 import torch @@ -16,7 +21,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder, - AttentionType) + AttentionState, AttentionType) from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) @@ -46,6 +51,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: return FlashInferMetadataBuilder + @staticmethod + def get_state_cls() -> Type["FlashInferState"]: + return FlashInferState + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -75,6 +84,160 @@ def get_supported_head_sizes() -> List[int]: return [64, 128, 256] +class FlashInferState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + self._workspace_buffer = None + self._decode_wrapper = None + self._prefill_wrapper = None + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") + return self._prefill_wrapper + + def _get_decode_wrapper(self): + if self._decode_wrapper is None: + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = num_qo_heads // num_kv_heads >= 4 + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + "NHD", + use_tensor_cores=use_tensor_cores) + return self._decode_wrapper + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_decode_wrapper = None + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + self._graph_decode_workspace_buffer = self._get_workspace_buffer() + self._graph_indices_buffer = torch.empty( + max_batch_size * self.runner.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.runner.device) + self._graph_indptr_buffer = torch.empty(max_batch_size + 1, + dtype=torch.int32, + device=self.runner.device) + self._graph_last_page_len_buffer = torch.empty( + max_batch_size, dtype=torch.int32, device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del 
self._graph_decode_workspace_buffer + del self._graph_indices_buffer + del self._graph_indptr_buffer + del self._graph_last_page_len_buffer + del self._graph_decode_wrapper + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + state = self.__class__(self.runner) + state._workspace_buffer = self._graph_decode_workspace_buffer + state._decode_wrapper = self._graph_decode_wrapper + state._prefill_wrapper = self._get_prefill_wrapper() + return state + + def graph_capture_get_metadata_for_batch(self, batch_size: int): + assert self._is_graph_capturing + _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] + _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] + + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = num_qo_heads // num_kv_heads >= 4 + self._graph_decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self._graph_decode_workspace_buffer, _indptr_buffer, + self._graph_indices_buffer, _last_page_len_buffer, "NHD", + use_tensor_cores) + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + paged_kv_indices_tensor_host = torch.arange(0, + batch_size, + dtype=torch.int32) + paged_kv_last_page_len_tensor_host = torch.full((batch_size, ), + self.runner.block_size, + dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=self._graph_slot_mapping[:batch_size], + num_prefill_tokens=0, + num_decode_tokens=batch_size, + max_prefill_seq_len=0, + block_tables=self._graph_block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len=paged_kv_last_page_len_tensor_host, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=self.runner.model_config.get_head_size(), + page_size=self.runner.block_size, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.runner.device, + data_type=kv_cache_dtype, + use_cuda_graph=True, + decode_wrapper=self._graph_decode_wrapper, + prefill_wrapper=None) + attn_metadata.begin_forward() + return attn_metadata + + def get_graph_input_buffers(self, attn_metadata): + return { + "slot_mapping": attn_metadata.slot_mapping, + } + + def prepare_graph_input_buffers(self, input_buffers, attn_metadata): + return + + def begin_forward(self, model_input): + assert not self._is_graph_capturing + state = self + if model_input.attn_metadata.use_cuda_graph: + batch_size = model_input.input_tokens.shape[0] + state = (self.runner.graph_runners[model_input.virtual_engine] + [batch_size].attn_state) + model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( + ) + model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() + model_input.attn_metadata.begin_forward() + + @dataclass class FlashInferMetadata(AttentionMetadata): # Maximum sequence length among prefill batch. 
0 if there are decoding diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index bac30aec24826..64d60e4e47e48 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -8,6 +8,7 @@ from vllm._ipex_ops import ipex_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -28,6 +29,10 @@ def get_impl_cls() -> Type["IpexAttnBackendImpl"]: def get_metadata_cls() -> Type["IpexAttnMetadata"]: return IpexAttnMetadata + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/backends/openvino.py b/vllm/attention/backends/openvino.py index 0f21b50ad4dc7..7992c70f52659 100644 --- a/vllm/attention/backends/openvino.py +++ b/vllm/attention/backends/openvino.py @@ -1,11 +1,12 @@ from dataclasses import dataclass -from typing import List, Tuple +from typing import List, Tuple, Type import openvino as ov import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata) +from vllm.attention.backends.utils import CommonAttentionState class OpenVINOAttentionBackend(AttentionBackend): @@ -24,6 +25,10 @@ def get_impl_cls(): def make_metadata(*args, **kwargs) -> "AttentionMetadata": raise NotImplementedError + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata": return OpenVINOAttentionMetadata(*args, **kwargs) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 4ecf698c8d514..ac03b6d8b1ead 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -6,6 +6,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState class PallasAttentionBackend(AttentionBackend): @@ -18,6 +19,10 @@ def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: def get_metadata_cls() -> Type["PallasMetadata"]: return PallasMetadata + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e305679231d02..b0f4d0530b7f0 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -7,7 +7,8 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import CommonMetadataBuilder +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -33,6 +34,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: return ROCmFlashAttentionMetadataBuilder + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, 
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index b83c673f0165e..8a1f8f2930c84 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -8,6 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.utils import is_cpu @@ -34,6 +35,10 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return TorchSDPAMetadata + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index e6b5f820c5fa0..0375d3488eb15 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,12 +1,17 @@ """Attention backend utils""" -from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union import numpy as np import torch -from vllm.attention import AttentionMetadata, AttentionMetadataBuilder +from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, + AttentionState) from vllm.utils import async_tensor_h2d, make_tensor_with_pad +if TYPE_CHECKING: + from vllm.worker.model_runner_base import ModelRunnerBase + # Error string(s) for encoder/decoder # unsupported attention scenarios STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " @@ -269,3 +274,69 @@ def build(self, seq_lens: List[int], query_lens: List[int], block_tables=block_tables, use_cuda_graph=use_captured_graph, ) + + +class CommonAttentionState(AttentionState): + + def __init__(self, runner: "ModelRunnerBase"): + self.runner = runner + self._is_graph_capturing = False + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + + def graph_clone(self, batch_size: int) -> "CommonAttentionState": + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch(self, batch_size: int): + assert self._is_graph_capturing + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + use_cuda_graph=True, + ) + return attn_metadata + + def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]: + return { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": 
attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + + def prepare_graph_input_buffers(self, input_buffers, + attn_metadata) -> None: + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + + def begin_forward(self, model_input) -> None: + return diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 7e36509bff864..e073d616bf01d 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import CommonMetadataBuilder +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -37,6 +38,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_builder_cls() -> Type["XFormersMetadataBuilder"]: return XFormersMetadataBuilder + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/config.py b/vllm/config.py index a5a9984a0114a..0d5d098bc8858 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform -from vllm.tracing import is_otel_installed +from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, cuda_device_count_stateless, get_cpu_memory, is_cpu, @@ -1721,9 +1721,11 @@ class ObservabilityConfig: collect_model_execute_time: bool = False def __post_init__(self): - if not is_otel_installed() and self.otlp_traces_endpoint is not None: - raise ValueError("OpenTelemetry packages must be installed before " - "configuring 'otlp_traces_endpoint'") + if not is_otel_available() and self.otlp_traces_endpoint is not None: + raise ValueError( + "OpenTelemetry is not available. Unable to configure " + "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " + f"installed. Original error:\n{otel_import_error_traceback}") if ((self.collect_model_forward_time or self.collect_model_execute_time) diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index 1e808e21b72e5..eb190adfbe802 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -1,4 +1,5 @@ from collections import deque +from dataclasses import dataclass from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple from vllm.core.block.interfaces import Block, BlockAllocator @@ -282,6 +283,58 @@ def ids(self) -> List[int]: return self._block_ids +@dataclass +class CacheMetricData: + """A utility dataclass to maintain cache metric. + To avoid overflow, we maintain the hit rate in block granularity, so that + we can maintain a single hit rate for n_completed_block x block_size, + and calculate the real time hit rate by the following: + BS = The number of queries per block. + nB = The number of completed blocks. + HR = hit rate of (nB x BS) queries. 
+ Q = current number of queries (< BS). + H = current number of hits (< BS). + hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS) + """ + num_completed_blocks: int = 0 + completed_block_cache_hit_rate: float = 0.0 + num_incompleted_block_queries: int = 0 + num_incompleted_block_hit: int = 0 + block_size: int = 1000 + + def query(self, hit: bool): + self.num_incompleted_block_queries += 1 + self.num_incompleted_block_hit += 1 if hit else 0 + + # When a block is completed, update the cache hit rate + # and reset the incomplete numbers. + if self.num_incompleted_block_queries == self.block_size: + hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + self.completed_block_cache_hit_rate = ( + self.completed_block_cache_hit_rate * self.num_completed_blocks + + hit_rate) / (self.num_completed_blocks + 1) + self.num_incompleted_block_queries = 0 + self.num_incompleted_block_hit = 0 + self.num_completed_blocks += 1 + + def get_hit_rate(self): + incomplete_ratio = self.num_incompleted_block_queries / self.block_size + total_blocks = self.num_completed_blocks + incomplete_ratio + if total_blocks == 0: + return 0.0 + + completed_block_hit, incompleted_block_hit = 0.0, 0.0 + if self.num_completed_blocks > 0: + completed_block_hit = (self.completed_block_cache_hit_rate * + self.num_completed_blocks) + if self.num_incompleted_block_queries > 0: + incompleted_hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio) + return (completed_block_hit + incompleted_block_hit) / total_blocks + + def get_all_blocks_recursively(last_block: Block) -> List[Block]: """Retrieves all the blocks in a sequence starting from the last block. diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 5287cd9c1bfb3..c6330df2a485a 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -323,6 +323,11 @@ def get_common_computed_block_ids( def all_block_ids(self) -> FrozenSet[int]: return frozenset(self._block_ids_to_allocator.keys()) + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + assert device in self._allocators + return self._allocators[device].get_prefix_cache_hit_rate() + def get_and_reset_swaps(self) -> List[Tuple[int, int]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index ab39832bc1f6e..f26bc761c9967 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -186,6 +186,11 @@ def get_num_blocks_touched(self, num_lookahead_slots: int = 0) -> int: pass + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + class NoFreeBlocksError(ValueError): pass @@ -278,3 +283,8 @@ def allocate_or_get_null_block(self) -> Block: There is at most one null block per allocator. """ pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. 
-1 means not supported or disabled.""" + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 14a62c2e7190e..1643fd69c58ab 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -341,6 +341,9 @@ def swap_in(self, blocks: List[Block]) -> None: block.block_id = block_id # Assign block_id + def get_prefix_cache_hit_rate(self) -> float: + return -1 + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index e145eeba2d66e..432a6651ab07a 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,9 +1,8 @@ """Token blocks.""" - from os.path import commonprefix from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple -from vllm.core.block.common import (CopyOnWriteTracker, +from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.naive_block import (BlockPool, NaiveBlock, @@ -107,6 +106,8 @@ def __init__( self._cow_tracker = CopyOnWriteTracker( refcounter=self._refcounter.as_readonly()) + self.metric_data = CacheMetricData() + # Implements Block.Factory. def _create_block( self, @@ -155,9 +156,11 @@ def allocate_immutable_block(self, cached_block_id = self._cached_blocks.get(block.content_hash, None) if cached_block_id is not None: + self.metric_data.query(hit=True) block.block_id = cached_block_id self._incr_refcount_cached_block(block) return block + self.metric_data.query(hit=False) self._block_pool.free_block(block) # No cached block => Allocate a new block @@ -404,6 +407,9 @@ def get_physical_block_id(self, absolute_id: int) -> int: def all_block_ids(self) -> FrozenSet[int]: return self._hashless_allocator.all_block_ids + def get_prefix_cache_hit_rate(self) -> float: + return self.metric_data.get_hit_rate() + def is_block_cached(self, block: Block) -> bool: assert block.content_hash is not None if block.content_hash in self._cached_blocks: diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index ad26d3c516ff0..0af04399a4b31 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,6 +8,7 @@ from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock +from vllm.core.block.common import CacheMetricData from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager @@ -60,6 +61,11 @@ def contains_block(self, block_hash: int) -> bool: def update_hash(self, block_hash: int, block: PhysicalTokenBlock): pass + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + class CachedBlockAllocator(BlockAllocatorBase): """Manages free physical token blocks for a device. 
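# Illustrative sketch of how the CacheMetricData helper introduced above
# accumulates the prefix-cache hit rate (hypothetical block_size of 4 for
# brevity; the patch defaults to 1000). One completed block with a 3/4 hit
# rate plus a partial block with 1 hit out of 2 queries yields
# (0.75 * 1 + 1/4) / (1 + 2/4) = 2/3, which is what get_hit_rate() reports
# and what BlockSpaceManager.get_prefix_cache_hit_rate(Device.GPU) surfaces
# (-1 means prefix caching is disabled or unsupported).

from vllm.core.block.common import CacheMetricData

data = CacheMetricData(block_size=4)
for hit in (True, True, False, True,   # completed block: hit rate 0.75
            True, False):              # partial block: 1 hit / 2 queries
    data.query(hit)
assert abs(data.get_hit_rate() - 2 / 3) < 1e-6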
@@ -85,6 +91,8 @@ def __init__(self, self.default_hash_ctr = count() + self.cache_metric_data = CacheMetricData() + def allocate_block(self, block_hash: int, num_hashed_tokens: int) -> PhysicalTokenBlock: if self.current_num_blocks == self.num_blocks: @@ -105,15 +113,17 @@ def allocate(self, num_hashed_tokens: int = 0) -> PhysicalTokenBlock: if block_hash is None: block_hash = next(self.default_hash_ctr) + if block_hash in self.evictor: assert block_hash not in self.cached_blocks block = self.evictor.remove(block_hash) assert block.ref_count == 0 self.cached_blocks[block_hash] = block - block.ref_count += 1 - assert block.block_hash == block_hash - return block - if block_hash not in self.cached_blocks: + + if block_hash in self.cached_blocks: + self.cache_metric_data.query(hit=True) + else: + self.cache_metric_data.query(hit=False) self.cached_blocks[block_hash] = self.allocate_block( block_hash, num_hashed_tokens) block = self.cached_blocks[block_hash] @@ -150,6 +160,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock): del self.cached_blocks[old_hash] self.cached_blocks[block_hash] = block + def get_prefix_cache_hit_rate(self) -> float: + return self.cache_metric_data.get_hit_rate() + class UncachedBlockAllocator(BlockAllocatorBase): """Manages free physical token blocks for a device. @@ -209,6 +222,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock): raise NotImplementedError( "Invalid codepath for uncached block allocator.") + def get_prefix_cache_hit_rate(self) -> float: + return -1 + class BlockSpaceManagerV1(BlockSpaceManager): """Manages the mapping between logical and physical token blocks.""" @@ -705,3 +721,10 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup): if self.enable_caching: for seq in seq_group.get_seqs(): self.compute_full_blocks_in_seq(seq) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + if device == Device.GPU: + return self.gpu_allocator.get_prefix_cache_hit_rate() + if device == Device.CPU: + return self.cpu_allocator.get_prefix_cache_hit_rate() + raise ValueError(f"Invalid device: {device}") diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b48ea1b19b82a..b7d9451f18067 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -441,6 +441,9 @@ def get_num_free_gpu_blocks(self) -> int: def get_num_free_cpu_blocks(self) -> int: return self.block_allocator.get_num_free_blocks(Device.CPU) + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_allocator.get_prefix_cache_hit_rate(device) + def _can_swap(self, seq_group: SequenceGroup, device: Device, diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index f2d67306d7ceb..3d864a73f91d0 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -2,6 +2,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device class EmbeddingModelBlockSpaceManager(BlockSpaceManager): @@ -81,3 +82,6 @@ def get_common_computed_block_ids(self, def mark_blocks_as_computed(self, seq_group: SequenceGroup): pass + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return -1 diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py index 5b1a208b7c866..0b943e6e65f1c 100644 --- a/vllm/core/evictor_v2.py +++ b/vllm/core/evictor_v2.py @@ -85,19 +85,21 @@ def evict(self) -> Tuple[int, int]: if 
len(self.free_table) == 0: raise ValueError("No usable cache memory left") - evicted_block = next(iter(self.free_table.values())) - evicted_block_id = next(iter(self.free_table.keys())) + evicted_block, evicted_block_id = None, None # The blocks with the lowest timestamps should be placed consecutively # at the start of OrderedDict. Loop through all these blocks to # find the one with maximum number of hashed tokens. for _id, block in self.free_table.items(): + if evicted_block is None: + evicted_block, evicted_block_id = block, _id + continue if evicted_block.last_accessed < block.last_accessed: break - if (evicted_block.last_accessed == block.last_accessed and - evicted_block.num_hashed_tokens < block.num_hashed_tokens): - evicted_block = block - evicted_block_id = _id + if evicted_block.num_hashed_tokens < block.num_hashed_tokens: + evicted_block, evicted_block_id = block, _id + assert evicted_block is not None + assert evicted_block_id is not None self.free_table.pop(evicted_block_id) return evicted_block_id, evicted_block.content_hash @@ -110,7 +112,6 @@ def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, def update(self, block_id: int, last_accessed: float): self.free_table[block_id].last_accessed = last_accessed - self.free_table.move_to_end(block_id) def remove(self, block_id: int): if block_id not in self.free_table: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 8759ee06795b8..becd0d2e7f849 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -5,6 +5,7 @@ from typing import Tuple from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device class AllocStatus(enum.Enum): @@ -116,3 +117,8 @@ def get_common_computed_block_ids( @abstractmethod def mark_blocks_as_computed(self, seq_group: SequenceGroup): pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. 
-1 means not supported or disabled.""" + pass diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 802359d2283f7..3b716e32032c1 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -14,7 +14,7 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStatus) -from vllm.utils import PyObjectCache +from vllm.utils import Device, PyObjectCache logger = init_logger(__name__) @@ -447,6 +447,9 @@ def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( self.swapped) != 0 + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_manager.get_prefix_cache_hit_rate(device) + def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8fca2cc049958..7f45c3d06375a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, Union) +import torch + import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -113,7 +115,7 @@ class EngineArgs: fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None - lora_dtype: str = 'auto' + lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' max_cpu_loras: Optional[int] = None device: str = 'auto' num_scheduler_steps: int = 1 @@ -662,8 +664,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--disable-logprobs-during-spec-decoding', - type=bool, + action=StoreBoolean, default=EngineArgs.disable_logprobs_during_spec_decoding, + nargs="?", + const="True", help='If set to True, token log probabilities are not returned ' 'during speculative decoding. If set to False, log probabilities ' 'are returned according to the settings in SamplingParams. If ' @@ -851,6 +855,12 @@ def create_engine_config(self, ) -> EngineConfig: "in low performance due to small KV cache space. 
Consider " "setting --max-model-len to a smaller value.", max_model_len) + if self.num_scheduler_steps > 1 and not self.use_v2_block_manager: + self.use_v2_block_manager = True + logger.warning( + "Enabled BlockSpaceManagerV2 because it is " + "required for multi-step (--num-scheduler-steps > 1)") + speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, target_parallel_config=parallel_config, @@ -879,7 +889,6 @@ def create_engine_config(self, ) -> EngineConfig: ) if self.num_scheduler_steps > 1: - raise NotImplementedError("Multi-step is not yet supported.") if speculative_config is not None: raise ValueError("Speculative decoding is not supported with " "multi-step (--num-scheduler-steps > 1)") diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index dced804fccca9..6385d3ca2297e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,9 +1,11 @@ import asyncio import time +from dataclasses import dataclass from functools import partial from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) +import torch from transformers import PreTrainedTokenizer from typing_extensions import assert_never @@ -27,7 +29,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once @@ -249,9 +252,25 @@ def has_new_requests(self): return not self._new_requests.empty() +@dataclass +class SchedulerOutputState: + """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" + last_output: Optional[SamplerOutput] = None + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + + class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pipeline_parallel_size = \ + self.parallel_config.pipeline_parallel_size + self.cached_scheduler_outputs = [ + SchedulerOutputState() for _ in range(pipeline_parallel_size) + ] + async def step_async( self, virtual_engine: int ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: @@ -264,13 +283,39 @@ async def step_async( and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - virtual_engine].schedule() + # these are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + # skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. 
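# Illustrative sketch (not part of the patch): the caching idea described in
# the comments above, in miniature -- call schedule() once, then reuse the
# cached batch until its remaining steps run out. ToyScheduler, CachedSchedule
# and the step() helper are assumptions made for this example, not vLLM APIs.
from dataclasses import dataclass
from typing import List, Optional


class ToyScheduler:

    def __init__(self) -> None:
        self.calls = 0

    def schedule(self) -> List[int]:
        self.calls += 1
        return [1, 2, 3]  # stands in for seq_group_metadata_list


@dataclass
class CachedSchedule:
    batch: Optional[List[int]] = None
    remaining_steps: int = 0


def step(cache: CachedSchedule, scheduler: ToyScheduler,
         num_steps: int) -> Optional[List[int]]:
    # Re-run the scheduler only when the current multi-step batch is finished.
    if cache.remaining_steps == 0:
        cache.batch = scheduler.schedule()
        cache.remaining_steps = num_steps
    cache.remaining_steps -= 1
    return cache.batch


scheduler, cache = ToyScheduler(), CachedSchedule()
for _ in range(8):
    step(cache, scheduler, num_steps=4)
assert scheduler.calls == 2  # eight engine steps, only two schedule() calls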
+        if not self._has_remaining_steps(seq_group_metadata_list):
+            seq_group_metadata_list, scheduler_outputs = self.scheduler[
+                virtual_engine].schedule()
+
+            if (self.scheduler_config.is_multi_step
+                    and scheduler_outputs.num_lookahead_slots > 0):
+                # cache the scheduler outputs for the next iteration if we have
+                # lookahead slots
+                self._cache_scheduler_outputs_for_multi_step(
+                    virtual_engine, seq_group_metadata_list, scheduler_outputs)
+
+        assert seq_group_metadata_list is not None
+        assert scheduler_outputs is not None
         if not scheduler_outputs.is_empty():
-            # Execute the model.
             finished_requests_ids = self.scheduler[
                 virtual_engine].get_and_reset_finished_requests_ids()
+
+            # Check if we have a cached last_output from the previous iteration.
+            # For supporting PP this is probably the best way to pass the
+            # sampled_token_ids, as a separate broadcast over all the PP stages
+            # will cause one virtual engine's microbatch to block the pipeline.
+            last_sampled_token_ids = \
+                self._get_last_sampled_token_ids(virtual_engine)
+
             execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
@@ -279,15 +324,35 @@ async def step_async(
                 virtual_engine=virtual_engine,
                 num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                 running_queue_size=scheduler_outputs.running_queue_size,
-                finished_requests_ids=finished_requests_ids)
+                finished_requests_ids=finished_requests_ids,
+                # We use ExecuteModelRequest to pass the last sampled_token_ids
+                # to each of the non-last PP stages for in-place prepare_input.
+                last_sampled_token_ids=last_sampled_token_ids)
+            # Execute the model.
             output = await self.model_executor.execute_model_async(
                 execute_model_req)
+            # we need to do this here so that last step's sampled_token_ids can
+            # be passed to the next iteration for PP.
+            if self.scheduler_config.is_multi_step:
+                self._update_cached_scheduler_output(virtual_engine, output)
         else:
             output = []
-        request_outputs = self._process_model_outputs(
-            output, scheduler_outputs.scheduled_seq_groups,
-            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
+        # Finish the current step for all the sequence groups.
+        if self.scheduler_config.is_multi_step:
+            for seq_group in seq_group_metadata_list:
+                seq_group.finish_step()
+
+        if not self._has_remaining_steps(seq_group_metadata_list):
+            # clear the cache if we have finished all the steps
+            if self.scheduler_config.is_multi_step:
+                self.cached_scheduler_outputs[
+                    virtual_engine] = SchedulerOutputState()
+            request_outputs = self._process_model_outputs(
+                output, scheduler_outputs.scheduled_seq_groups,
+                scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
+        else:
+            request_outputs = []
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
@@ -297,6 +362,60 @@ async def step_async(
         return request_outputs
+    def _has_remaining_steps(
+        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
+    ) -> bool:
+        if (not self.scheduler_config.is_multi_step
+                or not seq_group_metadata_list):
+            return False
+
+        # TODO(will) this is a sanity check for now to make sure that all the
+        # seqs are on the same steps. Eventually we will want to do some sort of
+        # dynamic scheduling when doing multi-step decoding.
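# Sketch of the invariant the TODO above describes, with made-up inputs: every
# scheduled group must sit on the same step before the engine advances. The
# plain list of ints stands in for the per-group state that vLLM keeps on
# SequenceGroupMetadata; this is an illustration, not the engine's code path.
from typing import Sequence


def has_remaining_steps(remaining_steps_per_group: Sequence[int]) -> bool:
    if not remaining_steps_per_group:
        return False
    ref = remaining_steps_per_group[0]
    if any(steps != ref for steps in remaining_steps_per_group[1:]):
        raise AssertionError("All running sequence groups should "
                             "have the same remaining steps.")
    return ref > 0


assert has_remaining_steps([2, 2, 2]) is True
assert has_remaining_steps([0, 0]) is False
assert has_remaining_steps([]) is False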
+ ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps + if any([ + seq_group.state.remaining_steps != ref_remaining_steps + for seq_group in seq_group_metadata_list[1:] + ]): + raise AssertionError(("All running sequence groups should " + "have the same remaining steps.")) + + return ref_remaining_steps > 0 + + def _cache_scheduler_outputs_for_multi_step( + self, virtual_engine: int, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + scheduler_outputs: SchedulerOutputs) -> None: + self.cached_scheduler_outputs[ + virtual_engine].seq_group_metadata_list = seq_group_metadata_list + self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \ + scheduler_outputs + self.cached_scheduler_outputs[virtual_engine].last_output = None + + def _get_last_sampled_token_ids( + self, virtual_engine: int) -> Optional[torch.Tensor]: + cached_last_output = self.cached_scheduler_outputs[ + virtual_engine].last_output + if (self.scheduler_config.is_multi_step + and self.parallel_config.pipeline_parallel_size > 1 + and cached_last_output is not None + and cached_last_output.sampled_token_ids_cpu is not None): + return cached_last_output.sampled_token_ids_cpu + return None + + def _update_cached_scheduler_output( + self, virtual_engine: int, + output: List[Optional[SamplerOutput]]) -> None: + if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 + and output[0] is not None): + last_output = output[-1] + assert last_output is not None + assert last_output.sampled_token_ids_cpu is not None + assert last_output.sampled_token_ids is None + assert last_output.sampled_token_probs is None + self.cached_scheduler_outputs[ + virtual_engine].last_output = last_output + async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fcf45a38b9425..36cb6ce795f3e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -47,7 +47,7 @@ AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter +from vllm.utils import Counter, Device from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -1390,6 +1390,13 @@ def _get_stats( for scheduler in self.scheduler) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + # Prefix Cache Hit Rate. Note that we always use + # the cache hit rate of the first virtual engine. + cpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.CPU) + gpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.GPU) + # Iteration stats num_prompt_tokens_iter = 0 num_generation_tokens_iter = 0 @@ -1498,6 +1505,9 @@ def _get_stats( # KV Cache Usage in % gpu_cache_usage_sys=gpu_cache_usage_sys, cpu_cache_usage_sys=cpu_cache_usage_sys, + # Prefix Cache Hit Rate + cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, # Iteration stats num_prompt_tokens_iter=num_prompt_tokens_iter, diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 1071786c27cd6..74277cae7c8ef 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -71,6 +71,17 @@ def __init__(self, labelnames: List[str], max_model_len: int): documentation="CPU KV-cache usage. 
1 means 100 percent usage.", labelnames=labelnames, multiprocess_mode="sum") + # Prefix caching block hit rate + self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls( + name="vllm:cpu_prefix_cache_hit_rate", + documentation="CPU prefix cache block hit rate.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls( + name="vllm:gpu_prefix_cache_hit_rate", + documentation="GPU prefix cache block hit rate.", + labelnames=labelnames, + multiprocess_mode="sum") # Iteration stats self.counter_num_preemption = self._counter_cls( @@ -351,7 +362,13 @@ def log(self, stats: Stats) -> None: stats.gpu_cache_usage_sys * 100, stats.cpu_cache_usage_sys * 100, ) - + if (stats.cpu_prefix_cache_hit_rate >= 0 + or stats.gpu_prefix_cache_hit_rate >= 0): + logger.info( + "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", + stats.gpu_prefix_cache_hit_rate * 100, + stats.cpu_prefix_cache_hit_rate * 100, + ) if self.spec_decode_metrics is not None: logger.info( self._format_spec_decode_metrics_str( @@ -423,6 +440,10 @@ def _log_prometheus(self, stats: Stats) -> None: stats.gpu_cache_usage_sys) self._log_gauge(self.metrics.gauge_cpu_cache_usage, stats.cpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate, + stats.cpu_prefix_cache_hit_rate) + self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate, + stats.gpu_prefix_cache_hit_rate) # Iteration level data self._log_counter(self.metrics.counter_num_preemption, diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 7449aafc5aecb..1eccb23593408 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -32,6 +32,9 @@ class Stats: # KV Cache Usage in % gpu_cache_usage_sys: float cpu_cache_usage_sys: float + # Prefix caching block hit rate + cpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float # Iteration stats (should have _iter suffix) num_prompt_tokens_iter: int diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 55976f430254c..af426e31591f2 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -62,6 +62,18 @@ def _get_worker_kwargs( observability_config=self.observability_config, ) + def _get_worker_module_and_class(self) -> Tuple[str, str]: + if self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_worker" + worker_class_name = "MultiStepWorker" + elif self.speculative_config: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + else: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + return (worker_module_name, worker_class_name) + def _get_create_worker_kwargs( self, local_rank: int = 0, @@ -69,13 +81,12 @@ def _get_create_worker_kwargs( distributed_init_method: Optional[str] = None) -> Dict: worker_kwargs = self._get_worker_kwargs(local_rank, rank, distributed_init_method) - if self.speculative_config is None: - worker_kwargs.update(worker_module_name="vllm.worker.worker", - worker_class_name="Worker") - else: - worker_kwargs.update( - worker_module_name="vllm.spec_decode.spec_decode_worker", - worker_class_name="create_spec_worker") + + (worker_module_name, + worker_class_name) = self._get_worker_module_and_class() + worker_kwargs.update(worker_module_name=worker_module_name, + worker_class_name=worker_class_name) + return worker_kwargs def _create_worker(self, diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 
3a08ab4dbfd44..bddb95210dbc9 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -91,12 +91,8 @@ def _configure_ray_workers_use_nsight(self, return ray_remote_kwargs def _get_worker_wrapper_args(self) -> Dict[str, Any]: - if self.speculative_config is not None: - worker_module_name = "vllm.spec_decode.spec_decode_worker" - worker_class_name = "create_spec_worker" - else: - worker_module_name = "vllm.worker.worker" - worker_class_name = "Worker" + (worker_module_name, + worker_class_name) = self._get_worker_module_and_class() return dict( worker_module_name=worker_module_name, @@ -104,6 +100,10 @@ def _get_worker_wrapper_args(self) -> Dict[str, Any]: trust_remote_code=self.model_config.trust_remote_code, ) + # child class could overwrite this to return actual env vars. + def _get_env_vars_to_be_updated(self): + return self._env_vars_for_all_workers + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if (self.parallel_config.tensor_parallel_size == 1 @@ -228,8 +228,12 @@ def sort_by_driver_then_worker_ip(worker): "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for (node_id, _) in worker_node_and_gpu_ids] + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + self._run_workers("update_environment_variables", - all_args=all_args_to_update_environment_variables) + all_args=self._get_env_vars_to_be_updated()) if len(node_gpus) == 1: # in single node case, we don't need to get the IP address. diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 687f938cfb937..45c8a3db04e61 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -1,16 +1,16 @@ -from typing import List, Optional +from typing import List, Optional, Tuple, Union import torch from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig, SpeculativeConfig) + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import make_async -from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -30,6 +30,7 @@ def __init__( lora_config: Optional[LoRAConfig], prompt_adapter_config: Optional[PromptAdapterConfig], speculative_config: Optional[SpeculativeConfig], + observability_config: Optional[ObservabilityConfig], ) -> None: assert device_config.device_type == "xpu" assert (not speculative_config @@ -46,32 +47,23 @@ def __init__( self.device_config = device_config self.prompt_adapter_config = prompt_adapter_config self.speculative_config = None + self.observability_config = observability_config # Instantiate the worker and load the model to GPU. 
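# Hedged sketch of the dispatch pattern the executor changes above share: pick
# a worker module/class name from the configuration, then import it lazily.
# The boolean flags and load_worker_class() are simplified assumptions; in the
# patch this job is split between _get_worker_module_and_class() and the
# worker wrapper machinery.
import importlib
from typing import Tuple


def get_worker_module_and_class(is_multi_step: bool,
                                has_spec_config: bool) -> Tuple[str, str]:
    if is_multi_step:
        return ("vllm.worker.multi_step_worker", "MultiStepWorker")
    if has_spec_config:
        return ("vllm.spec_decode.spec_decode_worker", "create_spec_worker")
    return ("vllm.worker.worker", "Worker")


def load_worker_class(module_name: str, class_name: str):
    # Deferred import keeps optional backends off the default import path.
    return getattr(importlib.import_module(module_name), class_name)


assert get_worker_module_and_class(False, False) == ("vllm.worker.worker",
                                                     "Worker")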
self._init_executor() - def _create_worker(self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None): - if self.speculative_config is None: - worker_module_name = "vllm.worker.xpu_worker" - worker_class_name = "XPUWorker" - else: + def _get_worker_module_and_class(self) -> Tuple[str, str]: + if self.speculative_config is not None: raise NotImplementedError( "XPU does not support speculative decoding") - - wrapper = WorkerWrapperBase( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - ) - wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, - distributed_init_method)) - return wrapper.worker + else: + worker_module_name = "vllm.worker.xpu_worker" + worker_class_name = "XPUWorker" + return (worker_module_name, worker_class_name) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: output = self.driver_worker.execute_model(execute_model_req) return output diff --git a/vllm/logger.py b/vllm/logger.py index 3c6bf0803a624..77dddbfb60965 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -43,6 +43,7 @@ }, }, "version": 1, + "disable_existing_loggers": False } diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 6d5c834299961..d666fc293757b 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -10,8 +10,10 @@ import torch from vllm.triton_utils import HAS_TRITON +from vllm.utils import is_xpu -if HAS_TRITON: +# FIXME: xpu path doesn't support torch.library.custom_op +if HAS_TRITON and not is_xpu(): from vllm.lora.ops.bgmv_expand import bgmv_expand from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_shrink import bgmv_shrink diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 51f3ef5dbb325..49247cd5de42a 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -30,7 +30,9 @@ def forward_hip(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) def forward_xpu(self, *args, **kwargs): - raise NotImplementedError + # By default, we assume that XPU ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) def forward_cpu(self, *args, **kwargs): # By default, we assume that CPU ops are compatible with CUDA ops. 
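# Minimal sketch of the per-backend dispatch behind the forward_xpu change
# above: backend-specific forward_* methods fall back to a native reference
# implementation unless they are overridden. ToyCustomOp is an illustrative
# stand-in, not vLLM's actual CustomOp class.
class ToyCustomOp:

    def forward_native(self, x: float) -> float:
        return x * 2.0  # PyTorch-native reference implementation

    def forward_cuda(self, x: float) -> float:
        return self.forward_native(x)  # a real op would call a fused kernel

    def forward_xpu(self, x: float) -> float:
        # XPU now reuses the native path by default, mirroring the patch.
        return self.forward_native(x)

    def forward(self, x: float, backend: str) -> float:
        return getattr(self, f"forward_{backend}")(x)


assert ToyCustomOp().forward(3.0, backend="xpu") == 6.0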
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b4cc6daa3c41e..3824ed3570aeb 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -507,11 +507,16 @@ def weight_loader(self, loaded_shard_id if is_gguf_weight: - shard_size = loaded_weight.shape[output_dim] - shard_offset = loaded_weight.shape[output_dim] * \ - loaded_shard_id + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", None) + shard_shape = list(loaded_weight.shape) + shard_shape[output_dim] = shard_shape[output_dim] // tp_size param.shard_id.append(loaded_shard_id) - param.shard_size[loaded_shard_id] = loaded_weight.shape + param.shard_size[loaded_shard_id] = shard_shape + + input_dim = getattr(param, "input_dim", None) + input_size = loaded_weight.shape[input_dim] + param_data = param_data.narrow(input_dim, 0, input_size) param_data = param_data.narrow(output_dim, shard_offset, shard_size) @@ -863,8 +868,13 @@ def weight_loader(self, param, orig_qkv_offsets, loaded_shard_id) if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", None) + shard_shape = list(loaded_weight.shape) + shard_shape[output_dim] = shard_shape[output_dim] // tp_size param.shard_id.append(loaded_shard_id) - param.shard_size[loaded_shard_id] = loaded_weight.shape + param.shard_size[loaded_shard_id] = shard_shape + input_dim = getattr(param, "input_dim", None) input_size = loaded_weight.shape[input_dim] param_data = param_data.narrow(input_dim, 0, input_size) @@ -976,6 +986,7 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) # Special case for GGUF @@ -986,7 +997,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # Materialize GGUF UninitializedParameter if is_gguf_weight and isinstance(param, UninitializedParameter): - param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + weight_shape = list(loaded_weight.shape) + if input_dim: + weight_shape[input_dim] = weight_shape[input_dim] // tp_size + param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data if input_dim is not None: diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index a4e0a4d509608..a6a1ed5b0dee5 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -5,7 +5,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops -from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -39,9 +38,6 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig": - if get_tensor_model_parallel_world_size() > 1: - raise ValueError( - "GGUF quantization hasn't supported tensor parallelism yet.") return cls() def get_quant_method(self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 7f9081b257705..33f24ff5d54d3 100644 --- 
a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -81,7 +81,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): def quantize_weights(w: torch.Tensor, quant_type: ScalarType, group_size: int, - zero_points: bool = False): + zero_points: bool = False, + ref_zero_points_after_scales: bool = False): assert quant_type.is_integer(), \ "Floating point quantization may work but has not been tested" @@ -126,7 +127,13 @@ def quantize_weights(w: torch.Tensor, w_q = torch.clamp(w_q, min_q_val, max_q_val) # Compute ref (dequantized) - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. + if ref_zero_points_after_scales and zero_points: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s if quant_type.has_bias(): w_q += quant_type.bias diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 74e534aa76a9d..28f69cfbc46bd 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -414,6 +414,8 @@ def __init__(self, config.hidden_size, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok self.unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index a11c7663263c6..73711d8eb5185 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -331,6 +331,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index ef988532ce126..f78400b0df7b3 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -821,6 +821,8 @@ def __init__(self, lora_config: Optional[LoRAConfig] = None): super().__init__() + # currently all existing BART models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings self.config = config self.model = BartModel(config, cache_config, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 8cfd3c2672568..20dda2a67820d 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -494,6 +494,9 @@ def __init__(self, super().__init__() + # currently all existing BLIP-2 models have `tie_word_embeddings` + # enabled + assert config.tie_word_embeddings self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 282a0f84eacb1..07ee0e3c531d0 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding 
import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -276,7 +276,12 @@ def __init__( self.config = config self.quant_config = quant_config self.transformer = BloomModel(config, cache_config, quant_config) - self.lm_head = self.transformer.word_embeddings + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.word_embeddings + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index b29ebe2f59e7b..4949d0232fabb 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -356,6 +356,9 @@ def __init__( self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) self.transformer = ChatGLMModel(config, cache_config, quant_config) + if self.config.tie_word_embeddings: + self.transformer.output_layer.weight = ( + self.transformer.embedding.weight) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 0894f750e5fbf..f63cf246e510a 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -321,6 +321,9 @@ def __init__( ) -> None: super().__init__() self.config = config + # currently all existing command R models have `tie_word_embeddings` + # enabled + assert config.tie_word_embeddings self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 7ebeca1a359ef..dca959798e8b2 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -362,6 +362,9 @@ def __init__( ): super().__init__() self.config = config + if config.tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Dbrx models.") self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size self.transformer = DbrxModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index f10977ed2c90d..7a27e1388e987 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -380,6 +380,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 7a9ee3d9477ca..e1041edf81b0a 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -331,6 +331,8 @@ def __init__( super().__init__() self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/gemma2.py 
b/vllm/model_executor/models/gemma2.py index ff547c2c3b8ab..5e0f8b70d4b80 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -323,6 +323,8 @@ def __init__( del lora_config # Unused. super().__init__() self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings self.quant_config = quant_config self.model = Gemma2Model(config, cache_config, quant_config) self.logits_processor = LogitsProcessor( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 4f2fe0c42a3ff..bfc231282952a 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -249,7 +249,11 @@ def __init__( cache_config, quant_config, prefix="transformer") - self.lm_head = self.transformer.wte + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index b30af3599aa4d..b93fb8d69b2d7 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -259,7 +259,13 @@ def __init__( self.quant_config = quant_config self.transformer = GPTBigCodeModel(config, cache_config, quant_config, lora_config) - self.lm_head = self.transformer.wte + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead( + self.transformer.vocab_size, + self.transformer.embed_dim, + org_num_embeddings=self.config.vocab_size) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index e61b4448981e8..2adecf7fa9ef8 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, - config, + config: GPTNeoXConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -243,6 +243,8 @@ def __init__( config.hidden_size, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 
216458465513a..887a353df972c 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -264,6 +264,8 @@ def __init__( self.output = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index ec6bea920cc3a..a550f7e6c97a1 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -37,7 +37,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -291,7 +291,11 @@ def __init__( self.config = config self.quant_config = quant_config self.transformer = JAISModel(config, cache_config, quant_config) - self.lm_head = self.transformer.wte + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 46db364895b13..6433ea380cbfe 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -313,7 +313,7 @@ def forward( 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`. To reserve space in KV cache, we have to insert placeholder tokens - before they are inputted to the model, so the input processor prepends + before they are inputted to the model, so the input processor prepends additional image tokens (denoted as `32000`), resulting in: `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, @@ -331,7 +331,7 @@ def forward( input_ids: Flattened (concatenated) input_ids corresponding to a batch. pixel_values: The pixels in each input image. 
- + See also: :class:`LlavaImageInputs` """ diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 729bd27c334d5..99a3c5dab39e4 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -496,6 +496,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() + # All MiniCPM-V models disable `tie_word_embeddings` but + # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot + # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model + # and config class self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 587d2f26a2d5e..34f581ac78582 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -359,6 +359,8 @@ def __init__( if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 812dce5d04771..8bdd52b343175 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -347,6 +347,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index b05f799e4dd2b..c0d2d537e731f 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -307,7 +307,11 @@ def __init__( self.config = config self.quant_config = quant_config self.model = OPTModel(config, cache_config, quant_config) - self.lm_head = self.model.decoder.embed_tokens + if self.config.tie_word_embeddings: + self.lm_head = self.model.decoder.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.word_embed_proj_dim) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 6923e11e288be..fab35f0b882a7 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -262,6 +262,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 54f4dd2fcde0a..f31b5162aac96 100644 --- a/vllm/model_executor/models/phi.py +++ 
b/vllm/model_executor/models/phi.py @@ -260,6 +260,8 @@ def __init__( super().__init__() self.config = config + # lm_head use bias, cannot share word embeddings + assert not config.tie_word_embeddings self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 98e344d483e29..df01bfa3d8e6e 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -368,6 +368,8 @@ def __init__( padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -449,4 +451,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 1c8bb8a837c86..328f4e6fa827c 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -477,6 +477,8 @@ def __init__(self, self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a7485bcb489a0..b7d017d5f3ea6 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -252,6 +252,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index e160c9a320820..6f838947fbf27 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -385,6 +385,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index c98226d61a8a0..decbf89d27c7c 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -243,6 +243,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index e9bf67d314d0a..c0bafa9367e43 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -313,6 +313,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + 
self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 99ba940e5d2ab..958f6c516a2f8 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,23 +1,54 @@ -import torch - from .interface import Platform, PlatformEnum, UnspecifiedPlatform current_platform: Platform +# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because +# they only indicate the build configuration, not the runtime environment. +# For example, people can install a cuda build of pytorch but run on tpu. + +is_tpu = False +try: + import torch_xla.core.xla_model as xm + xm.xla_device(devkind="TPU") + is_tpu = True +except Exception: + pass + +is_cuda = False + +try: + import pynvml + pynvml.nvmlInit() + try: + if pynvml.nvmlDeviceGetCount() > 0: + is_cuda = True + finally: + pynvml.nvmlShutdown() +except Exception: + pass + +is_rocm = False + try: - import libtpu -except ImportError: - libtpu = None + import amdsmi + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + is_rocm = True + finally: + amdsmi.amdsmi_shut_down() +except Exception: + pass -if libtpu is not None: +if is_tpu: # people might install pytorch built with cuda but run on tpu # so we need to check tpu first from .tpu import TpuPlatform current_platform = TpuPlatform() -elif torch.version.cuda is not None: +elif is_cuda: from .cuda import CudaPlatform current_platform = CudaPlatform() -elif torch.version.hip is not None: +elif is_rocm: from .rocm import RocmPlatform current_platform = RocmPlatform() else: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c7557dc34ff64..84301afabe9d8 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -18,6 +18,12 @@ _P = ParamSpec("_P") _R = TypeVar("_R") +if pynvml.__file__.endswith("__init__.py"): + logger.warning( + "You are using a deprecated `pynvml` package. Please install" + " `nvidia-ml-py` instead. See https://pypi.org/project/pynvml " + "for more information.") + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # all the related functions work on real physical device ids. diff --git a/vllm/sequence.py b/vllm/sequence.py index b15955cde76cf..206da192193dc 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -9,7 +9,6 @@ Tuple, Union, cast) import msgspec -import numpy import torch from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs @@ -1082,7 +1081,10 @@ class SamplerOutput( # On-device tensor containing the sampled token ids. sampled_token_ids: Optional[torch.Tensor] = None - sampled_token_ids_numpy: Optional[numpy.ndarray] = None + # CPU tensor containing the sampled token ids. Used during multi-step to + # return the sampled token ids from last rank to AsyncLLMEngine to be + # 'broadcasted' to all other PP ranks for next step. + sampled_token_ids_cpu: Optional[torch.Tensor] = None # Spec decode metrics populated by workers. 
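# Hedged sketch of the runtime-probing idea used in vllm/platforms above: ask
# the driver library whether devices are actually present instead of trusting
# the torch build flags. Only the pynvml calls shown (nvmlInit,
# nvmlDeviceGetCount, nvmlShutdown) are real; the wrapper function is an
# assumption for illustration and reports False on any failure.
def cuda_available_at_runtime() -> bool:
    try:
        import pynvml
        pynvml.nvmlInit()
        try:
            return pynvml.nvmlDeviceGetCount() > 0
        finally:
            pynvml.nvmlShutdown()
    except Exception:
        return False


if __name__ == "__main__":
    print("CUDA devices visible:", cuda_available_at_runtime())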
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None @@ -1257,9 +1259,7 @@ def is_last_step(self) -> bool: assert len(self.seq_group_metadata_list) > 0 first_seq_group = self.seq_group_metadata_list[0] assert first_seq_group.state is not None - num_steps = first_seq_group.state.num_steps - current_step = first_seq_group.state.current_step - return num_steps - current_step == 1 + return first_seq_group.state.remaining_steps == 1 @property def current_step(self) -> int: diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index aec4847b96c35..ad6f3f313841d 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -1,6 +1,6 @@ from array import array from itertools import chain, count -from typing import Iterator, List, Tuple +from typing import Iterator, List, Optional, Tuple import torch @@ -88,21 +88,22 @@ def score_proposals( assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] - all_tokens, all_probs, spec_logprobs = self._contract_batch( - contracted_bs=len(execute_model_req.seq_group_metadata_list), - target_sampler_output=target_sampler_output, - proposals=proposals, - num_scoring_tokens=num_scoring_tokens, - non_spec_indices=non_spec_indices, - spec_indices=spec_indices, - k=execute_model_req.num_lookahead_slots, - ) + (all_tokens, all_probs, spec_logprobs, + all_hidden_states) = self._contract_batch( + contracted_bs=len(execute_model_req.seq_group_metadata_list), + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=execute_model_req.num_lookahead_slots, + ) return SpeculativeScores( probs=all_probs, token_ids=all_tokens, logprobs=spec_logprobs, - hidden_states=target_sampler_output.hidden_states, + hidden_states=all_hidden_states, ) def _expand_batch( @@ -145,10 +146,11 @@ def _expand_batch( num_scoring_tokens) def _contract_batch( - self, contracted_bs: int, target_sampler_output: SamplerOutput, - proposals: SpeculativeProposals, num_scoring_tokens: int, - non_spec_indices: List[int], spec_indices: List[int], - k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, contracted_bs: int, target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], k: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. @@ -156,9 +158,10 @@ def _contract_batch( contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to. 
""" - (target_token_ids, target_probs, target_logprobs, + (target_token_ids, target_probs, target_logprobs, target_hidden_states, non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = self._split_scoring_output( + non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token @@ -176,23 +179,40 @@ def _contract_batch( self._vocab_size) target_logprobs = target_logprobs.reshape(target_probs.shape) + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + spec_expanded_bs, k + 1, target_hidden_states.shape[-1]) + all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), fill_value=-1) all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) all_logprobs = target_logprobs.new_full(size=all_probs.shape, fill_value=-float("inf")) + if target_sampler_output.hidden_states is not None: + all_hidden_states = target_hidden_states.new_zeros( + size=(contracted_bs, k + 1, target_hidden_states.shape[-1])) + else: + all_hidden_states = None + if non_spec_indices: all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs + if all_hidden_states is not None: + all_hidden_states[ + non_spec_indices, :1, :] = non_spec_target_hidden_states + if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs all_logprobs[spec_indices] = target_logprobs - return all_tokens, all_probs, all_logprobs + if all_hidden_states is not None: + all_hidden_states[spec_indices] = target_hidden_states + + return all_tokens, all_probs, all_logprobs, all_hidden_states def _create_scoring_model_input( self, @@ -327,8 +347,9 @@ def _create_single_target_seq_group_metadata( def _split_scoring_output( self, sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: """Split the target model output into speculative and non-speculative output. """ @@ -353,24 +374,37 @@ def _split_scoring_output( non_spec_logprobs, ) = sampler_output.logprobs.split(split_sizes) + if sampler_output.hidden_states is not None: + ( + spec_hidden_states, + non_spec_hidden_states, + ) = sampler_output.hidden_states.split(split_sizes) + else: + spec_hidden_states, non_spec_hidden_states = None, None + # Convert scores to tensors. sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens sampler_output.logprobs = spec_logprobs - (target_token_ids, target_probs, - target_logprobs) = sampler_output_to_torch([sampler_output], True) + sampler_output.hidden_states = spec_hidden_states + (target_token_ids, target_probs, target_logprobs, + target_hidden_states) = sampler_output_to_torch([sampler_output], + True) # Convert non-speculative output tokens to tensors. 
sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens sampler_output.logprobs = non_spec_logprobs + sampler_output.hidden_states = non_spec_hidden_states (non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = sampler_output_to_torch([sampler_output], - True) + non_spec_target_logprobs, + non_spec_target_hidden_states) = sampler_output_to_torch( + [sampler_output], True) return (target_token_ids, target_probs, target_logprobs, - non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) + target_hidden_states, non_spec_target_token_ids, + non_spec_target_probs, non_spec_target_logprobs, + non_spec_target_hidden_states) def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 1bb3b83744fec..053e9203e01eb 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -11,17 +11,6 @@ from vllm.attention.backends.rocm_flash_attn import ( ROCmFlashAttentionMetadata as FlashAttentionMetadata) -try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper - from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper - from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - BatchDecodeWithPagedKVCacheWrapper = None - CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None - BatchPrefillWithPagedKVCacheWrapper = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 - from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -90,11 +79,6 @@ def __init__( observability_config=observability_config, ) - self.flashinfer_decode_workspace_buffer = None - self.flashinfer_decode_wrapper = None - self.flashinfer_prefill_workspace_buffer = None - self.flashinfer_prefill_wrapper = None - def _update_sampling_metadata(self, sampling_metadata, num_seqs, num_queries): @@ -270,36 +254,7 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - if self.attn_backend.get_name() == "flashinfer": - assert model_input.attn_metadata is not None - assert model_input.input_tokens is not None - if self.flashinfer_decode_workspace_buffer is None: - self.flashinfer_decode_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_decode_wrapper = \ - BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_decode_workspace_buffer, "NHD") - self.flashinfer_prefill_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_prefill_wrapper = \ - BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_prefill_workspace_buffer, "NHD") - - model_input.attn_metadata.prefill_wrapper = \ - self.flashinfer_prefill_wrapper - if model_input.attn_metadata.use_cuda_graph: - batch_size = model_input.input_tokens.shape[0] - model_input.attn_metadata.decode_wrapper = \ - self.graph_runners[model_input. 
- virtual_engine][batch_size].flashinfer_decode_wrapper - else: - model_input.attn_metadata.decode_wrapper = \ - self.flashinfer_decode_wrapper - model_input.attn_metadata.begin_forward() + self.attn_state.begin_forward(model_input) # Detect exec mode assert model_input.attn_metadata is not None diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 63a00139cc09d..acf77a7349eef 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -646,9 +646,8 @@ def _verify_tokens( hidden_states = proposal_scores.hidden_states if hidden_states is not None: # Contract hidden states based on accepted tokens - hs_size = hidden_states.shape[1] - hidden_states = hidden_states.reshape(-1, max_proposal_len + 1, - hs_size) + hs_size = hidden_states.shape[-1] + accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 1a56497030280..28f7f7eb069ab 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -242,7 +242,7 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs, _ = sampler_output_to_torch( + proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( sampler_output, sampler_transposed) # Now, reformat the output GPU tensors such that each sequence has diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index c6223a97dba10..b85f2a6f70ac0 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -123,7 +123,7 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( sampler_output_list: List[SamplerOutput], sampler_transposed: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Utility function which converts a list of SamplerOutput to tensors. 
sampler_transposed here is used as the indicator for whether @@ -169,7 +169,23 @@ def sampler_output_to_torch( if sampler_transposed: sampled_token_ids = sampled_token_ids.transpose(0, 1) - return sampled_token_ids, sampled_token_probs, sampled_token_logprobs + if sampler_output_list[0].hidden_states is not None: + # shape: [batch_size, num_sampler_output, hidden_dim] + sampled_hidden_states = torch.stack( + [ + sampler_output.hidden_states + for sampler_output in sampler_output_list + ], + dim=0, + ) + + if sampler_transposed: + sampled_hidden_states = sampled_hidden_states.transpose(0, 1) + else: + sampled_hidden_states = None + + return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs, + sampled_hidden_states) def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, diff --git a/vllm/tracing.py b/vllm/tracing.py index 8bd71b8fd9ea5..31849e2b635aa 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -8,7 +8,8 @@ logger = init_logger(__name__) -_is_otel_installed = False +_is_otel_imported = False +otel_import_error_traceback: Optional[str] = None try: from opentelemetry.context.context import Context from opentelemetry.sdk.environment_variables import ( @@ -19,8 +20,14 @@ from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider from opentelemetry.trace.propagation.tracecontext import ( TraceContextTextMapPropagator) - _is_otel_installed = True + _is_otel_imported = True except ImportError: + # Capture and format traceback to provide detailed context for the import + # error. Only the string representation of the error is retained to avoid + # memory leaks. + # See https://github.com/vllm-project/vllm/pull/7266#discussion_r1707395458 + import traceback + otel_import_error_traceback = traceback.format_exc() class Context: # type: ignore pass @@ -35,14 +42,17 @@ class Tracer: # type: ignore pass -def is_otel_installed() -> bool: - return _is_otel_installed +def is_otel_available() -> bool: + return _is_otel_imported def init_tracer(instrumenting_module_name: str, otlp_traces_endpoint: str) -> Optional[Tracer]: - assert is_otel_installed(), ("OpenTelemetry packages must be installed " - "prior to initializing a tracer") + if not is_otel_available(): + raise ValueError( + "OpenTelemetry is not available. Unable to initialize " + "a tracer. Ensure OpenTelemetry packages are installed. 
" + f"Original error:\n{otel_import_error_traceback}") trace_provider = TracerProvider() span_exporter = get_span_exporter(otlp_traces_endpoint) @@ -70,7 +80,7 @@ def get_span_exporter(endpoint): def extract_trace_context( headers: Optional[Mapping[str, str]]) -> Optional[Context]: - if is_otel_installed(): + if is_otel_available(): headers = headers or {} return TraceContextTextMapPropagator().extract(headers) else: diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 1afda0e45b702..5c700229660c0 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -6,6 +6,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata) +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, get_global_forced_attn_backend, global_force_attn_backend) @@ -20,7 +21,7 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, SequenceGroupMetadata) from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad -from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase, +from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -395,7 +396,7 @@ def _prepare_encoder_model_input_tensors( # initialized yet. In this case, we just use a dummy # slot mapping. # In embeddings, the block tables are {seq_id: None}. - cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len) else: for i in range(0, seq_len): block_number = seq_group_metadata.cross_block_table[ diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9f27c734efd1e..793f03456e997 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -13,19 +13,10 @@ import torch.distributed import torch.nn as nn -try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper - from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper - from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - BatchDecodeWithPagedKVCacheWrapper = None - CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None - BatchPrefillWithPagedKVCacheWrapper = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 - import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention.backends.abstract import AttentionState +from vllm.attention.backends.utils import CommonAttentionState from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -52,8 +43,7 @@ from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, get_kv_cache_torch_dtype, is_hip, - is_pin_memory_available) + flatten_2d_lists, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -66,7 +56,6 @@ logger = init_logger(__name__) -_PAD_SLOT_ID = -1 LORA_WARMUP_RANK = 8 _BATCH_SIZE_ALIGNMENT = 8 # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. 
@@ -858,6 +847,11 @@ def __init__( self.kv_cache_dtype, self.block_size, ) if num_attn_heads else None + if self.attn_backend: + self.attn_state = self.attn_backend.get_state_cls()( + weakref.proxy(self)) + else: + self.attn_state = CommonAttentionState(weakref.proxy(self)) # Multi-modal data support self.input_registry = input_registry @@ -872,11 +866,6 @@ def __init__( self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None - self.flashinfer_decode_workspace_buffer = None - self.flashinfer_decode_wrapper = None - self.flashinfer_prefill_workspace_buffer = None - self.flashinfer_prefill_wrapper = None - set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) @@ -1203,10 +1192,6 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() - slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() - slot_mapping.fill_(_PAD_SLOT_ID) - seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() - block_tables = torch.from_numpy(self.graph_block_tables).cuda() intermediate_inputs = None if not get_pp_group().is_first_rank: intermediate_inputs = self.model.make_empty_intermediate_tensors( @@ -1226,102 +1211,16 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - if self.attn_backend.get_name() == "flashinfer": - # For flashinfer, different batch sizes will share the - # same workspace buffer. - decode_workspace_buffer = \ - torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - indices_buffer = torch.empty(max_batch_size * - self.cache_config.num_gpu_blocks, - dtype=torch.int32, - device=self.device) - indptr_buffer = torch.empty(max_batch_size + 1, - dtype=torch.int32, - device=self.device) - last_page_len_buffer = torch.empty(max_batch_size, - dtype=torch.int32, - device=self.device) - - with graph_capture() as graph_capture_context: + with self.attn_state.graph_capture( + max_batch_size), graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
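# The hunks above and below replace the model runner's inline flashinfer
# handling with a backend-provided attention state object. The standalone
# sketch that follows is only an illustration of that pattern, not vLLM's
# actual AttentionState definition: the method names mirror the calls visible
# in this diff (graph_capture, graph_capture_get_metadata_for_batch,
# graph_clone, get_graph_input_buffers, prepare_graph_input_buffers,
# begin_forward), while the signatures and bodies are assumptions.

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict


class AttentionStateSketch(ABC):
    """Owns backend-specific buffers (e.g. workspace and decode wrappers)."""

    def __init__(self, runner: Any) -> None:
        # The model runner hands in a weakref.proxy of itself so the state
        # can read runner configuration without creating a reference cycle.
        self.runner = runner

    @abstractmethod
    def begin_forward(self, model_input: Any) -> None:
        """Per-step setup before the model forward pass."""

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Allocate any capture-time buffers for CUDA-graph capture."""
        yield self

    @abstractmethod
    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> Any:
        """Build dummy attention metadata for capturing one batch size."""

    def graph_clone(self, batch_size: int) -> "AttentionStateSketch":
        """State object handed to the CUDAGraphRunner for this batch size."""
        return self

    def get_graph_input_buffers(self, attn_metadata: Any) -> Dict[str, Any]:
        """Backend-specific tensors saved as CUDA-graph input buffers."""
        return {"slot_mapping": attn_metadata.slot_mapping}

    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                    attn_metadata: Any) -> None:
        """Copy per-step metadata into the captured input buffers."""

# With such a state object, CUDAGraphRunner and the draft model runner only
# call into attn_state, so adding CUDA-graph support for a new backend means
# shipping a state class rather than adding more get_name() branches here.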
for virtual_engine in range( self.parallel_config.pipeline_parallel_size): for batch_size in reversed(batch_size_capture_list): - if self.attn_backend.get_name() == "flashinfer": - _indptr_buffer = indptr_buffer[:batch_size + 1] - _last_page_len_buffer = last_page_len_buffer[: - batch_size] - - num_qo_heads = ( - self.model_config.get_num_attention_heads( - self.parallel_config)) - num_kv_heads = self.model_config.get_num_kv_heads( - self.parallel_config) - if num_qo_heads // num_kv_heads >= 4: - use_tensor_cores = True - else: - use_tensor_cores = False - decode_wrapper = \ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - decode_workspace_buffer, _indptr_buffer, - indices_buffer, _last_page_len_buffer, "NHD", - use_tensor_cores) - kv_cache_dtype = get_kv_cache_torch_dtype( - self.kv_cache_dtype, self.model_config.dtype) - - paged_kv_indptr_tensor_host = torch.arange( - 0, batch_size + 1, dtype=torch.int32) - paged_kv_indices_tensor_host = torch.arange( - 0, batch_size, dtype=torch.int32) - paged_kv_last_page_len_tensor_host = torch.full( - (batch_size, ), self.block_size, dtype=torch.int32) - query_start_loc_host = torch.arange(0, - batch_size + 1, - dtype=torch.int32) - - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - slot_mapping=slot_mapping[:batch_size], - num_prefill_tokens=0, - num_decode_tokens=batch_size, - max_prefill_seq_len=0, - block_tables=block_tables, - paged_kv_indptr=paged_kv_indptr_tensor_host, - paged_kv_indices=paged_kv_indices_tensor_host, - paged_kv_last_page_len= - paged_kv_last_page_len_tensor_host, - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_dim=self.model_config.get_head_size(), - page_size=self.block_size, - seq_start_loc=None, - query_start_loc=query_start_loc_host, - device=self.device, - data_type=kv_cache_dtype, - use_cuda_graph=True, - decode_wrapper=decode_wrapper, - prefill_wrapper=None) - attn_metadata.begin_forward() - else: - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=seq_lens[:batch_size], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables[:batch_size], - use_cuda_graph=True, - ) + attn_metadata = ( + self.attn_state.graph_capture_get_metadata_for_batch( + batch_size)) if self.lora_config: lora_mapping = LoRAMapping( @@ -1339,17 +1238,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: set(), prompt_adapter_mapping) graph_runner = CUDAGraphRunner( - self.model, self.attn_backend.get_name()) - - if self.attn_backend.get_name() == "flashinfer": - graph_runner.flashinfer_indptr_buffer = _indptr_buffer - graph_runner.flashinfer_indices_buffer = indices_buffer - graph_runner.flashinfer_last_page_len_buffer = \ - _last_page_len_buffer - graph_runner.flashinfer_decode_workspace_buffer = \ - decode_workspace_buffer - graph_runner.flashinfer_decode_wrapper = \ - decode_wrapper + self.model, self.attn_backend.get_name(), + self.attn_state.graph_clone(batch_size)) capture_inputs = { "input_ids": @@ -1476,36 +1366,7 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - if self.attn_backend.get_name() == "flashinfer": - assert model_input.attn_metadata is not None - assert model_input.input_tokens is not None - if self.flashinfer_decode_workspace_buffer is 
None: - self.flashinfer_decode_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_decode_wrapper = \ - BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_decode_workspace_buffer, "NHD") - self.flashinfer_prefill_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_prefill_wrapper = \ - BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_prefill_workspace_buffer, "NHD") - - model_input.attn_metadata.prefill_wrapper = \ - self.flashinfer_prefill_wrapper - if model_input.attn_metadata.use_cuda_graph: - batch_size = model_input.input_tokens.shape[0] - model_input.attn_metadata.decode_wrapper = self.graph_runners[ - model_input. - virtual_engine][batch_size].flashinfer_decode_wrapper - else: - model_input.attn_metadata.decode_wrapper = \ - self.flashinfer_decode_wrapper - model_input.attn_metadata.begin_forward() + self.attn_state.begin_forward(model_input) # Currently cuda graph is only supported by the decode phase. assert model_input.attn_metadata is not None @@ -1613,22 +1474,17 @@ def execute_model( class CUDAGraphRunner: - def __init__(self, model: nn.Module, backend_name: str): + def __init__(self, model: nn.Module, backend_name: str, + attn_state: AttentionState): self.model = model self.backend_name = backend_name + self.attn_state = attn_state self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} self._graph: Optional[torch.cuda.CUDAGraph] = None - self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None - self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None - self.flashinfer_indices_buffer: Optional[torch.Tensor] = None - self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None - self.flashinfer_decode_wrapper: Optional[ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None - @property def graph(self): assert self._graph is not None @@ -1693,25 +1549,13 @@ def capture( torch.cuda.synchronize() # Save the input and output buffers. 
- if self.backend_name == "flashinfer": - self.input_buffers = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - "slot_mapping": attn_metadata.slot_mapping, - **kwargs, - } - else: - self.input_buffers = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": - attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - **kwargs, - } + self.input_buffers = { + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + **self.attn_state.get_graph_input_buffers(attn_metadata), + **kwargs, + } if intermediate_inputs is not None: self.input_buffers.update(intermediate_inputs.tensors) if get_pp_group().is_last_rank: @@ -1739,12 +1583,8 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - if self.backend_name != "flashinfer": - self.input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, - non_blocking=True) - self.input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) + self.attn_state.prepare_graph_input_buffers(self.input_buffers, + attn_metadata) if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 46ac16b504bf4..90c39407d7266 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -14,7 +14,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata -T = TypeVar('T', bound="ModelRunnerInputBase") +T = TypeVar('T', bound="BroadcastableModelInput") def _add_attn_metadata_broadcastable_dict( @@ -81,18 +81,26 @@ def _add_sampling_metadata_broadcastable_dict( sampling_metadata.selected_token_indices) -@dataclasses.dataclass(frozen=True) -class ModelRunnerInputBase(ABC): - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelRunnerInputBase objects. - - Model runners that support multi-GPU execution should define a - ModelRunnerInputBase subclass, add their required fields, and specify how to - serialize/deserialize a ModelInput for broadcast between workers. +def _init_frozen_model_input_from_tensor_dict( + frozen_model_input_cls: Type["ModelRunnerInputBase"], + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: """ + Helper method to initialize a frozen ModelInput based on broadcastable fields in the tensor_dict. + """ + valid_tensor_kwargs = {} + for field in dataclasses.fields(frozen_model_input_cls): + val = tensor_dict.pop(field.name, None) + if val is not None: + valid_tensor_kwargs[field.name] = val + + frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs) + tensor_dict["frozen_model_input"] = frozen_model_input + return tensor_dict + +class BroadcastableModelInput(ABC): + + @abstractmethod def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: """ Extract broadcastable fields. Override for fields that require some @@ -109,11 +117,25 @@ def from_broadcasted_tensor_dict( ) -> T: """ Pop fields from the given tensor_dict and populate a new instance of - ModelRunnerInputBase.
+ BroadcastableModelInput. """ raise NotImplementedError +@dataclasses.dataclass(frozen=True) +class ModelRunnerInputBase(BroadcastableModelInput): + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelRunnerInputBase objects. + + Model runners that support multi-GPU execution should define a + ModelRunnerInputBase subclass, add their required fields, and specify how to + serialize/deserialize a ModelInput for broadcast between workers. + """ + pass + + class ModelRunnerInputBuilderBase(ABC, Generic[T]): """A builder to create ModelRunnerInputBase objects. """ diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py new file mode 100644 index 0000000000000..521205eca05af --- /dev/null +++ b/vllm/worker/multi_step_model_runner.py @@ -0,0 +1,453 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +try: + from vllm.attention.backends.flash_attn import FlashAttentionMetadata +except ModuleNotFoundError: + # vllm_flash_attn is not installed, use the identical ROCm FA metadata + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata as FlashAttentionMetadata) + +import torch + +from vllm import _custom_ops as ops +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SamplerOutput, SequenceGroupMetadata, + SequenceOutput) +from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + BroadcastableModelInput, _init_attn_metadata_from_tensor_dict, + _init_frozen_model_input_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +from ..model_executor.model_loader.tensorizer import TensorizerConfig + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + + +@dataclass +class ModelOutput: + """The output of a single model forward pass. + + The sampler_output_ready_event is set when the tensors in + sampler_output are ready (the model+sampler forward pass has + completed). We use the event to synchronize the GPU->CPU transfer, + which we want to only run when the data has been written to the + GPU tensors. Until the event is ready, the tensors in sampler_output + will have garbage data. + + There are two scenarios: + 1. The output tensors are ready and we can pythonize them immediately. + 2. The output tensors are not ready and we need to wait for the event to be + ready. + """ + sampler_output: SamplerOutput + sampler_output_ready_event: torch.cuda.Event + sampled_token_ids: Optional[torch.Tensor] = None + pythonized: bool = False + + def pythonize(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output. Blocking.""" + if not self.pythonized: + self._pythonize_sampler_output(input_metadata, copy_stream, + pinned_sampled_token_buffer, True) + self.pythonized = True + + def maybe_pythonize(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output if ready, else return None. 
Non-blocking.""" + if not self.pythonized: + self.pythonized = self._pythonize_sampler_output( + input_metadata, copy_stream, pinned_sampled_token_buffer, + False) + + def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor, + blocking: bool) -> bool: + """ + If blocking is set, will block until the forward pass for the output is + ready and pythonize the output. + """ + assert self.sampled_token_ids is not None + if not blocking and not self.sampler_output_ready_event.query(): + return False + + if blocking: + self.sampler_output_ready_event.synchronize() + with torch.cuda.stream(copy_stream): + _pythonize_sampler_output(input_metadata, self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids) + return True + + +@dataclass(frozen=False) +class StatefulModelInput(BroadcastableModelInput): + # actual frozen model input dataclass passed to _base_model_runner + frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None + + # list of model outputs for each step, may not be all pythonized + cached_outputs: List[ModelOutput] = field(default_factory=list) + + # used to pass sampled token ids from the last step to the current step for + # TP workers. Used to append to end of outputs and used by advance_step + last_sampled_token_ids: Optional[torch.Tensor] = None + current_step: int = 0 + is_multi_step: bool = True + is_last_step: bool = False + is_first_multi_step: bool = False + # ping-pong data structures for multi-step to wait on the previous step + step_cuda_events: List[torch.cuda.Event] = field( + default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) + num_seqs: int = -1 + num_queries: int = -1 + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + assert self.frozen_model_input is not None + tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict() + new_tensor_dict = { + 'last_sampled_token_ids': self.last_sampled_token_ids, + 'current_step': self.current_step, + 'is_multi_step': self.is_multi_step, + 'is_last_step': self.is_last_step, + 'is_first_multi_step': self.is_first_multi_step, + 'num_seqs': self.num_seqs, + 'num_queries': self.num_queries, + } + tensor_dict.update(new_tensor_dict) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "StatefulModelInput": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + tensor_dict = _init_frozen_model_input_from_tensor_dict( + ModelInputForGPUWithSamplingMetadata, tensor_dict) + + return cls(**tensor_dict) + + def record_step_event(self, current_stream: torch.cuda.Stream): + # record the event for the current step so that the next step can sync + # on it. We modulo by 2 to keep the events in a circular buffer and + # support any attn backends that may be supported in the future. ie + # Flashinfer would want two DecodeWrappers to overlap the CPU and GPU. 
+ self.step_cuda_events[self.current_step & 1] = \ + torch.cuda.Event(blocking=True) + self.step_cuda_events[self.current_step & 1].record(current_stream) + + def wait_previous_step(self): + # These cuda events are an explicit synchronization to ensure that + # advance_step() (for other attn backends that may be supported in the + # future) does not clobber any data structures that are also used by any + # enqueued forward steps. For the distributed case, only a single event is + # needed, but for the single-GPU case, since we can let the CPU run much + # further ahead, two events allow us to overlap the advance_step with + # the previous forward (i.e. using two DecodeWrappers for the flashinfer + # backend) + self.step_cuda_events[(self.current_step + 1) & 1].wait() + + def add_sampler_output(self, + sampler_output: SamplerOutput, + sampled_token_ids: Optional[torch.Tensor] = None): + self.cached_outputs.append( + ModelOutput(sampler_output=sampler_output, + sampler_output_ready_event=None, + sampled_token_ids=sampled_token_ids, + pythonized=False)) + + +# MutableModelInputForGPUWithMultiStepMetadata is not a subclass of +# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step +# metadata +# mypy: disable-error-code=type-var +class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): + # mypy: enable-error-code=type-var + + def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): + super().__init__(*args, **kwargs) + + # uses the base model runner to execute the model and wraps it with + # multi-step logic + self._base_model_runner: GPUModelRunnerBase = base_model_runner + + self.is_multi_step = self.scheduler_config.is_multi_step + # used to copy tensors from GPU to CPU asynchronously + self._copy_stream = torch.cuda.Stream() + self.pinned_sampled_token_ids: Optional[torch.Tensor] = None + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: + model_input = (StatefulModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + return model_input + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> StatefulModelInput: + frozen_model_input = self._base_model_runner.prepare_model_input( + seq_group_metadata_list, virtual_engine, finished_requests_ids) + + model_input = StatefulModelInput( + frozen_model_input=frozen_model_input, + num_seqs=len(frozen_model_input.seq_lens), + num_queries=len(frozen_model_input.query_lens), + ) + return model_input + + @torch.inference_mode() + def execute_model( + self, + model_input: StatefulModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + """ + Execute the model for a single step and update multi-step + metadata + """ + assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1" + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + # path for warm-up runs + if not model_input.is_multi_step: + return self._base_model_runner.execute_model( + frozen_model_input, kv_caches, intermediate_tensors, num_steps) + + # make sure we skip the sampler on the last rank and only pythonize + # if CPU is ahead.
+ if self.is_driver_worker and get_pp_group().is_last_rank: + if self.pinned_sampled_token_ids is None: + self.pinned_sampled_token_ids = torch.zeros( + (self.scheduler_config.max_num_seqs, 1), + dtype=torch.long, + device="cpu", + pin_memory=True) + + self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( + True) + if frozen_model_input.sampling_metadata: + frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( + True) + + # some pre-execute model logic for multi-step: + # - if it's the first step, we need to reset the sampling tensors + # - if it's not the first step, we need to advance the step using the + # appended sampler output from the last iteration + # - also maybe pythonize if CPU is ahead of GPU + + current_stream = torch.cuda.current_stream() + if not model_input.is_first_multi_step: + # Explicitly block on the previous step's forward to make sure we + # don't clobber any GPU tensors still in use. + # This is not needed for the flashattn backend, but for other attn + # backends such as flashinfer that perform extra CPU operations on + # input metadata, we may need to synchronize any CPU operations that + # might clobber enqueued forwards (this prevents the CPU from running too + # far ahead if needed). + model_input.wait_previous_step() + model_input = self._advance_step( + model_input, model_input.cached_outputs[-1].sampler_output) + + # Execute the model + output = self._base_model_runner.execute_model(frozen_model_input, + kv_caches, + intermediate_tensors, + num_steps=1) + + # record the event for the current step so that the next step can sync + model_input.record_step_event(current_stream) + + if get_pp_group().is_last_rank and self.is_driver_worker: + assert len( + output + ) == 1, "MultiStepModelRunner requires single-step base_models" + + # event for the pythonization so that we only pythonize if the + # tensors are ready. It may be possible to combine this with the step event + output_ready_event = torch.cuda.Event() + output_ready_event.record(current_stream) + if self.parallel_config.pipeline_parallel_size > 1: + output[0].sampled_token_ids_cpu = output[ + 0].sampled_token_ids.cpu() + model_input.cached_outputs.append( + ModelOutput(output[0], output_ready_event, + output[0].sampled_token_ids, False)) + # make sure we don't try to serialize any GPU tensors + output[0].sampled_token_ids = None + output[0].sampled_token_probs = None + output[0].logprobs = None + # Pythonize the output if CPU is ahead and the previous step is + # ready.
+ for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + model_input.current_step += 1 + + if not get_pp_group().is_last_rank: + # Should be IntermediateTensors + assert isinstance(output, IntermediateTensors) + return output + if not self.is_driver_worker: + return [] + + # Pythonize the output and block if needed since it is the last step + if model_input.is_last_step: + outputs = [] + for output in model_input.cached_outputs: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + outputs.append(output.sampler_output) + return outputs + + # should be [SamplerOutput] + return output + + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _advance_step(self, model_input: StatefulModelInput, + out: SamplerOutput) -> StatefulModelInput: + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + + num_seqs = model_input.num_seqs + num_queries = model_input.num_queries + assert num_seqs > 0 + assert num_queries > 0 + assert num_seqs >= num_queries + + attn_metadata = frozen_model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + attn_metadata.advance_step(num_seqs, num_queries) + + # Update GPU tensors + ops.advance_step( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=self.block_size, + input_tokens=frozen_model_input.input_tokens, + sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids, + input_positions=frozen_model_input.input_positions, + seq_lens=attn_metadata.seq_lens_tensor, + slot_mapping=attn_metadata.slot_mapping, + block_tables=attn_metadata.block_tables) + + if frozen_model_input.seq_lens is not None: + for i in range(num_queries): + frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i] + + return model_input + + def load_model(self) -> None: + return self._base_model_runner.load_model() + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + return self._base_model_runner.save_sharded_state( + path, pattern, max_size) + + def save_tensorized_model(self, + tensorizer_config: TensorizerConfig) -> None: + return self._base_model_runner.save_tensorized_model(tensorizer_config) + + def profile_run(self) -> None: + return self._base_model_runner.profile_run() + + def remove_all_loras(self): + return self._base_model_runner.remove_all_loras() + + def capture_model(self, kv_caches: List[List]) -> None: + return self._base_model_runner.capture_model(kv_caches) + + @property + def vocab_size(self) -> int: + return self._base_model_runner.vocab_size + + +def _pythonize_sampler_output(model_input: StatefulModelInput, + output: SamplerOutput, + pinned_sampled_token_buffer: 
torch.Tensor, + sampled_token_ids: torch.Tensor) -> None: + """ This function is only called when the output tensors are ready. + See ModelOutput + """ + + assert model_input.frozen_model_input is not None + + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input.sampling_metadata is not None + # samples generation should have been skipped + assert not output.outputs + + pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] + + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False) + + # this will not block as the tensors are already on CPU + samples_list = pinned_buffer.tolist() + + sampling_metadata = frozen_model_input.sampling_metadata + + for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, + samples_list): + seq_ids = seq_group.seq_ids + next_token_ids = sample_result + parent_ids = [0] + seq_outputs: List[SequenceOutput] = [] + if seq_group.sampling_params.logits_processors: + assert len(seq_group.sampling_params.logits_processors) == 0, ( + "Logits Processors are not supported in multi-step decoding") + for parent_id, next_token_id in zip(parent_ids, next_token_ids): + # TODO(will): support logprobs + # Hard coded logprob + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + {next_token_id: Logprob(logprob=-1)})) + output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None)) + assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py new file mode 100644 index 0000000000000..6a6caba9371eb --- /dev/null +++ b/vllm/worker/multi_step_worker.py @@ -0,0 +1,189 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple + +from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.worker.model_runner_base import BroadcastableModelInput +from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, + StatefulModelInput) +from vllm.worker.worker import Worker, WorkerInput + + +@dataclass +class MultiStepState: + worker_input: WorkerInput + model_input: StatefulModelInput + + +class MultiStepWorker(Worker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + base_model_runner = self.model_runner + # for multi-step model, wrap the model runner with MultiStepModelRunner + self.model_runner = MultiStepModelRunner( + base_model_runner, + base_model_runner.model_config, + base_model_runner.parallel_config, + base_model_runner.scheduler_config, + base_model_runner.device_config, + base_model_runner.cache_config, + load_config=base_model_runner.load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=base_model_runner.is_driver_worker, + prompt_adapter_config=base_model_runner.prompt_adapter_config, + observability_config=base_model_runner.observability_config, + ) + + pipeline_parallel_size = self.parallel_config.pipeline_parallel_size + self.multi_step_states: List[ + Optional[MultiStepState]] = [None] * pipeline_parallel_size + self.temp_output = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput]: + """ + Get the driver input and broadcast it to other workers. 
+ """ + assert self.is_driver_worker + virtual_engine = execute_model_req.virtual_engine + is_first_multi_step = execute_model_req.is_first_multi_step + if is_first_multi_step: + # on first step we prepare the worker input and model input normally + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: StatefulModelInput = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + else: + # on subsequent steps we reuse the worker input and model input + multi_step_state = self.multi_step_states[virtual_engine] + worker_input = multi_step_state.worker_input + model_input = multi_step_state.model_input + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + # clear the cached decode metadata so that it can be recomputed on + # the workers + frozen_model_input.attn_metadata._cached_decode_metadata = None + + model_input.is_first_multi_step = is_first_multi_step + model_input.is_last_step = execute_model_req.is_last_step + + if not is_first_multi_step: + # we broadcast the last sampled token ids to all TP workers so they + # can update their model input metadata in-place. + self._prepare_last_sampled_token_ids_for_tp_workers( + execute_model_req=execute_model_req, model_input=model_input) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + + return model_input, worker_input + + def _prepare_last_sampled_token_ids_for_tp_workers( + self, + execute_model_req: ExecuteModelRequest, + model_input: StatefulModelInput, + ) -> None: + """ + Prepare the last sampled token ids for TP workers. If it's the last + PP rank, then the last sampled token ids are already in the model_input. + If it is NOT the last PP rank, then we need to get the last sampled + token that is cached in the execute_model_req. + """ + if get_pp_group().is_last_rank: + assert model_input.cached_outputs[ + -1].sampler_output.sampled_token_ids is None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + model_input.last_sampled_token_ids = model_input.cached_outputs[ + -1].sampled_token_ids + # free sampled token ids from the previous step if it has been + # pythonized. Cannot free the last sampled token ids because + # we need it for GPU advance_step. + for output in model_input.cached_outputs[:-1]: + if output.pythonized: + output.sampled_token_ids = None + else: + # otherwise we need to get the cached sampled token ids from the + # execute_model_req + assert execute_model_req.last_sampled_token_ids is not None + model_input.last_sampled_token_ids = ( + execute_model_req.last_sampled_token_ids.cuda()) + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + # free sampled token ids from the previous step. + # TODO(will) we could reuse the sampled token ids tensor from + # the previous step instead. 
+ for output in model_input.cached_outputs[:-1]: + output.sampled_token_ids = None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[StatefulModelInput, WorkerInput]]: + """ + Depending on the current state of the request and the multi-step worker, + this method may skip the normal _prepare_model_input and + _prepare_worker_input methods and instead use cached values. + """ + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there are no more requests to process for + # now. All workers are running an infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + + virtual_engine = execute_model_req.virtual_engine + model_input, worker_input = self._get_driver_input_and_broadcast( + execute_model_req) + assert isinstance(model_input, StatefulModelInput) + if execute_model_req.is_first_multi_step: + # cache the worker input and model input for the next steps + self.multi_step_states[virtual_engine] = MultiStepState( + worker_input=worker_input, model_input=model_input) + # if TP workers + else: + broadcast_data = self._get_worker_input_from_broadcast() + # if the driver has sent an empty input, we should stop the worker + # loop + if broadcast_data is None: + return None + model_input, worker_input = broadcast_data + assert isinstance(model_input, StatefulModelInput) + virtual_engine = worker_input.virtual_engine + if model_input.is_first_multi_step: + pass + # TODO(will) Can cache the worker input and model input for the + # next steps. See below for details + else: + # TODO(will) It is possible to also cache and reuse the cached worker + # input and model input. The idea is essentially the delta + # optimization for model_inputs, where the TP workers can cache + # the model input states and we only broadcast the delta needed + # for the next step (sampled_token_ids from the previous step). + + assert isinstance(model_input, StatefulModelInput) + # we need to update the last sampled token ids in the model + # input for the workers so that they can run in-place + # advance_step + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + assert model_input is not None + assert worker_input is not None + return model_input, worker_input diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 14f14e40b4c0b..01daa64b5a32f 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -516,27 +516,19 @@ def execute_model( raise ValueError( "TPUModelRunner does not support multi-step execution.") - def _execute_model(*args, clone: bool = False) -> torch.Tensor: + def _execute_model(*args): """Move input args from CPU to device and execute the model.""" - def _copy_to_device(x: torch.Tensor) -> torch.Tensor: - if clone: - # When x is a slice of a CPU tensor, XLA may copy the whole - # original tensor to TPU instead of only copying x. - # To avoid this, we copy x after cloning.
- x = x.clone() - return x.to(self.device) - new_args = [] for arg in args: if isinstance(arg, torch.Tensor): - arg = _copy_to_device(arg) + arg = arg.to(self.device) elif isinstance(arg, AttentionMetadata): - arg.slot_mapping = _copy_to_device(arg.slot_mapping) + arg.slot_mapping = arg.slot_mapping.to(self.device) if getattr(arg, "block_tables", None) is not None: - arg.block_tables = _copy_to_device(arg.block_tables) + arg.block_tables = arg.block_tables.to(self.device) if getattr(arg, "context_lens", None) is not None: - arg.context_lens = _copy_to_device(arg.context_lens) + arg.context_lens = arg.context_lens.to(self.device) new_args.append(arg) return self.model(*new_args) @@ -563,13 +555,9 @@ def _copy_to_device(x: torch.Tensor) -> torch.Tensor: output_token_ids = _execute_model( model_input.token_ids[None, start_idx:end_idx], model_input.position_ids[None, start_idx:end_idx], - model_input.attn_metadata, - model_input.input_lens[i:i + 1], - model_input.t[i:i + 1], - model_input.p[i:i + 1], - model_input.num_samples, - kv_caches, - clone=True) + model_input.attn_metadata, model_input.input_lens[i:i + 1], + model_input.t[i:i + 1], model_input.p[i:i + 1], + model_input.num_samples, kv_caches) # Retrieve the outputs to CPU. next_token_ids += output_token_ids.cpu().tolist() start_idx = end_idx diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 905052d1a9515..9fddc863548eb 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -16,7 +16,9 @@ SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) -from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase +from vllm.worker.model_runner_base import (BroadcastableModelInput, + ModelRunnerBase, + ModelRunnerInputBase) logger = init_logger(__name__) @@ -220,7 +222,7 @@ def execute_worker(self, worker_input: WorkerInput) -> None: raise NotImplementedError def _get_worker_input_from_broadcast( - self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]: + self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: """ Get the worker input from the broadcasted tensor dict. """ assert self.do_metadata_broadcast assert not self.is_driver_worker @@ -237,7 +239,7 @@ def _get_worker_input_from_broadcast( def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest - ) -> Tuple[ModelRunnerInputBase, WorkerInput]: + ) -> Tuple[BroadcastableModelInput, WorkerInput]: """ Get the driver input and broadcast it to other workers. """ assert self.is_driver_worker @@ -259,7 +261,7 @@ def _get_driver_input_and_broadcast( def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]: + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: """ Prepare the inputs to ModelRunner and workers. 
""" diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 8a2f93c15ed5e..0bfc57a1c57de 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -137,7 +137,6 @@ def load_model(self) -> None: device_config=self.device_config, load_config=self.load_config, lora_config=self.lora_config, - multimodal_config=self.multimodal_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, cache_config=self.cache_config, diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 0f22d67c4f254..7c8f5e0cf65ec 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -9,8 +9,8 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, + ModelConfig, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) @@ -50,6 +50,7 @@ def __init__( speculative_config: Optional[SpeculativeConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, + observability_config: Optional[ObservabilityConfig] = None, ) -> None: assert device_config.device_type == "xpu" assert is_xpu() @@ -67,8 +68,10 @@ def __init__( self.lora_config = lora_config self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker - if self.is_driver_worker: - assert self.rank == 0, "The driver worker must have rank 0." + self.observability_config = observability_config + if parallel_config and is_driver_worker: + assert rank % parallel_config.tensor_parallel_size == 0, \ + "Driver worker should be rank 0 of tensor parallel group." self.multimodal_config = multimodal_config @@ -183,7 +186,11 @@ def init_worker_distributed_environment(self) -> None: # dependency (libdrm and drm headers) on your system. ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "sockets") + ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", + str(parallel_config.world_size)) os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE + os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE + os.environ["LOCAL_RANK"] = str(self.local_rank) init_distributed_environment( world_size=parallel_config.world_size, rank=rank,