Commit
Merge branch 'main' into channelwise-scales
robertgshaw2-redhat committed Jul 18, 2024
2 parents 5f2cb45 + e76466d commit c72573d
Showing 16 changed files with 763 additions and 170 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -151,6 +151,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);

void gelu_quick(torch::Tensor& out, torch::Tensor& input);

void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
131 changes: 131 additions & 0 deletions csrc/prepare_inputs/advance_step.cu
@@ -0,0 +1,131 @@
/*
 * The goal of this GPU kernel is to advance input tensors on the GPU directly
 * PR: https://github.com/vllm-project/vllm/pull/6338
 * Current restrictions:
 *   1. Specialized for DraftModelRunner
 *   2. Supports flash_attn only
 */

#include "advance_step.cuh"

namespace prepare_inputs {

//
template <int const num_threads>
__global__ void advance_step_kernel(int num_seqs, int num_queries,
                                    int block_size, long* input_tokens_ptr,
                                    long const* sampled_token_ids_ptr,
                                    long* input_positions_ptr,
                                    int* seq_lens_ptr, long* slot_mapping_ptr,
                                    int const* block_tables_ptr,
                                    int64_t const block_tables_stride) {
  int num_query_blocks = div_ceil(num_queries, num_threads);

  if (blockIdx.x >= num_query_blocks) {
    return;
  }

  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

  if (cur_query_id >= num_queries) {
    return;
  }

  // Update input_tokens
  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

  int seq_len = seq_lens_ptr[cur_query_id];
  int next_seq_len = seq_len + 1;
  int next_input_pos = next_seq_len - 1;

  // Update seq_lens
  seq_lens_ptr[cur_query_id] = next_seq_len;
  // Update input_positions
  input_positions_ptr[cur_query_id] = next_input_pos;

  int const* seq_block_tables_ptr =
      block_tables_ptr + block_tables_stride * cur_query_id;

  int block_index = next_input_pos / block_size;
  int block_offset = next_input_pos % block_size;

  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
  // Update slot_mapping
  slot_mapping_ptr[cur_query_id] = slot_num;
}

inline void verify_tensor(std::string const& name, torch::Tensor& t,
                          int64_t const size_0, int64_t const size_1,
                          c10::ScalarType const type) {
  bool size_0_cond = true;
  if (size_0 != -1) {
    size_0_cond = t.size(0) == size_0;
  }

  bool size_1_cond = true;
  if (size_1 != -1) {
    size_1_cond = t.size(1) == size_1;
  }

  bool is_contiguous = t.is_contiguous();
  bool same_type = t.dtype() == type;

  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
  if (!pass) {
    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
                " is not as expected: shape = [", size_0, ", ", size_1,
                "], type = ", type);
  }
}

void advance_step(int num_seqs, int num_queries, int block_size,
                  torch::Tensor& input_tokens,       // type: long
                  torch::Tensor& sampled_token_ids,  // type: long
                  torch::Tensor& input_positions,    // type: long
                  torch::Tensor& seq_lens,           // type: int
                  torch::Tensor& slot_mapping,       // type: long
                  torch::Tensor& block_tables) {     // type: int

  if (logging) {
    printf("advance_step:\n");
    printf(" num_seqs = %d\n", num_seqs);
    printf(" num_queries = %d\n", num_queries);
    printf(" block_size = %d\n", block_size);
  }
  // Verify all tensors
  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
                at::kLong);
  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);

  int dev = sampled_token_ids.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
      num_seqs, num_queries, block_size,
      reinterpret_cast<long*>(input_tokens.data_ptr()),
      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
      reinterpret_cast<long*>(input_positions.data_ptr()),
      reinterpret_cast<int*>(seq_lens.data_ptr()),
      reinterpret_cast<long*>(slot_mapping.data_ptr()),
      reinterpret_cast<int const*>(block_tables.data_ptr()),
      block_tables.stride(0));
}

}  // namespace prepare_inputs

void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
                  torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
  prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
                               sampled_token_ids, input_positions, seq_lens,
                               slot_mapping, block_tables);
}
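For orientation, the per-query update that advance_step_kernel performs can be sketched in plain Python. This is an illustrative reference only, not code from the commit; the argument names mirror the kernel parameters above, and sampled_token_ids is assumed to have shape [num_queries, 1] as enforced by verify_tensor.

# Illustrative Python sketch of what advance_step_kernel does for one query q.
# Not part of the commit; names mirror the kernel arguments above.
def advance_one_query(q: int, block_size: int, input_tokens, sampled_token_ids,
                      input_positions, seq_lens, slot_mapping, block_tables):
    # Feed the token sampled in the previous step back in as the next input.
    input_tokens[q] = sampled_token_ids[q][0]

    # The sequence grows by one; the new token sits at the last position.
    seq_lens[q] += 1
    next_input_pos = seq_lens[q] - 1
    input_positions[q] = next_input_pos

    # Translate the logical position into a physical KV-cache slot via the
    # per-sequence block table: slot = physical_block * block_size + offset.
    block_index = next_input_pos // block_size
    block_offset = next_input_pos % block_size
    slot_mapping[q] = block_tables[q][block_index] * block_size + block_offset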
19 changes: 19 additions & 0 deletions csrc/prepare_inputs/advance_step.cuh
@@ -0,0 +1,19 @@
#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace prepare_inputs {

static constexpr int max_threads = 256;
static constexpr bool logging = false;

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

} // namespace prepare_inputs
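For context on these constants: the launcher in advance_step.cu starts one thread block per SM with max_threads threads each, and the num_query_blocks check in the kernel makes the surplus blocks exit immediately. A small illustrative sketch of that arithmetic (the SM count is an example value; the real one is queried with cudaDeviceGetAttribute):

# Illustrative arithmetic for the kernel launch, mirroring div_ceil above.
def div_ceil(a: int, b: int) -> int:
    return (a + b - 1) // b

max_threads = 256     # threads per block, as in advance_step.cuh
num_queries = 32      # example batch of decode queries
sm_count = 108        # example SM count (e.g. an A100); queried at runtime

blocks_launched = sm_count                              # <<<blocks, max_threads>>>
blocks_doing_work = div_ceil(num_queries, max_threads)  # here: 1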
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

// prepare_inputs advance_step
ops.def("advance_step", &advance_step);
ops.impl("advance_step", torch::kCUDA, &advance_step);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
54 changes: 54 additions & 0 deletions tests/metrics/test_metrics.py
@@ -1,11 +1,13 @@
from typing import List

import pytest
import ray
from prometheus_client import REGISTRY

from vllm import EngineArgs, LLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.sampling_params import SamplingParams

MODELS = [
@@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
                                                 labels)
        assert (
            metric_value == num_requests), "Metrics should be collected"


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
def test_engine_log_metrics_ray(
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# This test is quite weak - it only checks that we can use
# RayPrometheusStatLogger without exceptions.
# Checking whether the metrics are actually emitted is unfortunately
# non-trivial.

# We have to run in a Ray task for Ray metrics to be emitted correctly
@ray.remote(num_gpus=1)
def _inner():

class _RayPrometheusStatLogger(RayPrometheusStatLogger):

def __init__(self, *args, **kwargs):
self._i = 0
super().__init__(*args, **kwargs)

def log(self, *args, **kwargs):
self._i += 1
return super().log(*args, **kwargs)

engine_args = EngineArgs(
model=model,
dtype=dtype,
disable_log_stats=False,
)
engine = LLMEngine.from_engine_args(engine_args)
logger = _RayPrometheusStatLogger(
local_interval=0.5,
labels=dict(model_name=engine.model_config.served_model_name),
max_model_len=engine.model_config.max_model_len)
engine.add_logger("ray", logger)
for i, prompt in enumerate(example_prompts):
engine.add_request(
f"request-id-{i}",
prompt,
SamplingParams(max_tokens=max_tokens),
)
while engine.has_unfinished_requests():
engine.step()
assert logger._i > 0, ".log must be called at least once"

ray.get(_inner.remote())
1 change: 1 addition & 0 deletions tests/spec_decode/e2e/conftest.py
@@ -227,6 +227,7 @@ def get_output_from_llm_generator(
    maybe_assert_ngram_worker(llm)

    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

    token_ids = [output.outputs[0].token_ids for output in outputs]
    tokens = [output.outputs[0].text for output in outputs]

48 changes: 48 additions & 0 deletions tests/spec_decode/test_multi_step_worker.py
@@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k():
    assert proposals.proposal_lens.tolist() == [
        k for _ in range(expected_num_proposal_seqs - 1)
    ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]


@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
    """Verify that draft model runner triggers advance step
    when applicable.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    k = 5
    batch_size = 32
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )

    # Mock "_gpu_advance_step" to raise an exception when called.
    exception_secret = "artificial stop"
    worker.model_runner._gpu_advance_step = MagicMock()
    worker.model_runner._gpu_advance_step.side_effect = ValueError(
        exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    # Fallback (should not call) when num_steps=1.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        num_steps=1)
    worker.execute_model(execute_model_req=execute_model_req)

    # Expect exception if _gpu_advance_step is called.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        num_steps=k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)
    call_args_list = worker.model_runner._gpu_advance_step.call_args_list
    assert len(call_args_list) == 1
12 changes: 12 additions & 0 deletions vllm/_custom_ops.py
@@ -166,6 +166,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def advance_step(num_seqs: int, num_queries: int, block_size: int,
                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner"""
    return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
                                     input_tokens, sampled_token_ids,
                                     input_positions, seq_lens, slot_mapping,
                                     block_tables)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
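For illustration, a hedged sketch of calling the new wrapper. Shapes and dtypes follow the verify_tensor checks in csrc/prepare_inputs/advance_step.cu; the concrete sizes, token values, and block-table contents are invented for the example, and a CUDA device is assumed.

# Hedged usage sketch for vllm._custom_ops.advance_step; shapes/dtypes follow
# the verify_tensor checks in advance_step.cu, all concrete values are made up.
import torch
from vllm import _custom_ops as ops

num_seqs, num_queries, block_size, max_blocks = 4, 4, 16, 8
device = "cuda"

input_tokens = torch.zeros(num_seqs, dtype=torch.long, device=device)
sampled_token_ids = torch.randint(0, 32000, (num_queries, 1),
                                  dtype=torch.long, device=device)
input_positions = torch.zeros(num_seqs, dtype=torch.long, device=device)
seq_lens = torch.full((num_seqs,), 10, dtype=torch.int32, device=device)
slot_mapping = torch.zeros(num_seqs, dtype=torch.long, device=device)
block_tables = torch.arange(num_seqs * max_blocks, dtype=torch.int32,
                            device=device).reshape(num_seqs, max_blocks)

# In-place update of the decode inputs for the next step.
ops.advance_step(num_seqs, num_queries, block_size, input_tokens,
                 sampled_token_ids, input_positions, seq_lens, slot_mapping,
                 block_tables)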
19 changes: 19 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -12,6 +12,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.metrics import StatLoggerBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
@@ -389,6 +390,7 @@ def from_engine_args(
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
@@ -451,6 +453,7 @@ def from_engine_args(
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return engine

@@ -957,3 +960,19 @@ async def is_tracing_enabled(self) -> bool:
            )
        else:
            return self.engine.is_tracing_enabled()

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        if self.engine_use_ray:
            ray.get(
                self.engine.add_logger.remote(  # type: ignore
                    logger_name=logger_name, logger=logger))
        else:
            self.engine.add_logger(logger_name=logger_name, logger=logger)

    def remove_logger(self, logger_name: str) -> None:
        if self.engine_use_ray:
            ray.get(
                self.engine.remove_logger.remote(  # type: ignore
                    logger_name=logger_name))
        else:
            self.engine.remove_logger(logger_name=logger_name)
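For illustration, a hedged sketch of using the new add_logger/remove_logger hooks from application code. The RayPrometheusStatLogger arguments mirror the test added in tests/metrics/test_metrics.py above; the model name, interval, and max_model_len are arbitrary example values, and a local (non-Ray) engine is assumed.

# Hedged sketch: attach and later remove a stat logger on AsyncLLMEngine.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger

engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="facebook/opt-125m", disable_log_stats=False))

logger = RayPrometheusStatLogger(
    local_interval=0.5,
    labels=dict(model_name="facebook/opt-125m"),
    max_model_len=2048)

engine.add_logger("ray", logger)  # forwarded to the wrapped LLMEngine
# ... serve requests ...
engine.remove_logger("ray")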
2 changes: 2 additions & 0 deletions vllm/engine/llm_engine.py
@@ -379,6 +379,7 @@ def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
@@ -423,6 +424,7 @@
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return engine

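The new stat_loggers parameter lets callers supply their own loggers when the engine is built rather than attaching them afterwards. A hedged sketch, with logger arguments mirroring the test above and all concrete values invented for the example:

# Hedged sketch: pass custom stat loggers directly to from_engine_args.
from vllm import EngineArgs, LLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger

engine_args = EngineArgs(model="facebook/opt-125m", disable_log_stats=False)
engine = LLMEngine.from_engine_args(
    engine_args,
    stat_loggers={
        "ray": RayPrometheusStatLogger(local_interval=0.5,
                                       labels=dict(model_name="facebook/opt-125m"),
                                       max_model_len=2048),
    })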