[core] cudagraph output with tensor weak reference (vllm-project#9724)
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: qishuai <[email protected]>
youkaichao authored and FerdinandZhong committed Oct 29, 2024
1 parent 1de52e7 commit a1fb710
Showing 4 changed files with 50 additions and 28 deletions.
24 changes: 24 additions & 0 deletions csrc/ops.h
@@ -5,6 +5,30 @@

#include "core/scalar_type.hpp"

#include <vector>

torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
  // Ensure tensor is on CUDA
  if (!tensor.is_cuda()) {
    throw std::runtime_error("Tensor must be on CUDA device");
  }

  // Get the raw data pointer
  void* data_ptr = tensor.data_ptr();

  // Get tensor sizes and strides
  std::vector<int64_t> sizes = tensor.sizes().vec();
  std::vector<int64_t> strides = tensor.strides().vec();

  // Get tensor options (dtype, device)
  auto options = tensor.options();

  // Create a new tensor from the raw data pointer
  auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options);

  return new_tensor;
}

void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
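Note (not part of the diff): weak_ref_tensor wraps torch::from_blob, so the returned tensor aliases the input's storage without taking ownership of it. A minimal Python sketch of that aliasing behavior, assuming the compiled _C extension from this commit is built and that importing vllm._C is what registers the op with torch:

import torch

import vllm._C  # noqa: F401 -- assumed to register the custom ops (assumption)

x = torch.randn(4, device="cuda")
y = torch.ops._C.weak_ref_tensor(x)
assert y.data_ptr() == x.data_ptr()  # same underlying CUDA buffer, no copy
x.fill_(1.0)
print(y)  # reflects the in-place write to x, since the data is shared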
3 changes: 3 additions & 0 deletions csrc/torch_bindings.cpp
@@ -18,6 +18,9 @@
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
  ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);

  // Attention ops
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
9 changes: 9 additions & 0 deletions vllm/utils.py
@@ -1479,3 +1479,12 @@ def __iter__(self):

    def __len__(self):
        return len(self._factory)


def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """
    Create a weak reference to a tensor.
    The new tensor will share the same data as the original tensor,
    but will not keep the original tensor alive.
    """
    return torch.ops._C.weak_ref_tensor(tensor)
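Note (not part of the diff): because the helper returns a from_blob view, it does not extend the lifetime of the original tensor's storage; once the original is released, the view is only safe to read while some other owner, such as a captured CUDA graph's memory pool, still holds that buffer. A small sketch of the lifetime contract:

import torch
from vllm.utils import weak_ref_tensor

x = torch.empty(1024, device="cuda")
view = weak_ref_tensor(x)  # shares x's storage but holds no strong reference
del x
# The buffer is now only as alive as its real owner makes it; here nothing
# else owns it, so the allocator may reuse it and `view` must not be trusted.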
42 changes: 14 additions & 28 deletions vllm/worker/model_runner.py
@@ -50,7 +50,7 @@
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
                        flatten_2d_lists, is_hip, is_pin_memory_available,
                        supports_dynamo)
                        supports_dynamo, weak_ref_tensor)
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
@@ -1426,12 +1426,6 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                dtype=self.model_config.dtype,
                device=self.device)

        # Prepare buffer for outputs. These will be reused for all batch sizes.
        # It will be filled after the first graph capture.
        hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
            None
        ] * self.parallel_config.pipeline_parallel_size

        graph_batch_size = self.max_batchsize_to_capture
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
@@ -1474,12 +1468,6 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                        input_tokens[:batch_size],
                        "positions":
                        input_positions[..., :batch_size],
                        "hidden_or_intermediate_states":
                        hidden_or_intermediate_states[
                            virtual_engine]  # type: ignore
                        [:batch_size]
                        if hidden_or_intermediate_states[virtual_engine]
                        is not None else None,
                        "intermediate_inputs":
                        intermediate_inputs[:batch_size]
                        if intermediate_inputs is not None else None,
@@ -1762,15 +1750,13 @@ def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
    ):
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
@@ -1799,20 +1785,21 @@ def capture(
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )
            if hidden_or_intermediate_states is not None:
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (

            if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
                hidden_or_intermediate_states = weak_ref_tensor(
                    output_hidden_or_intermediate_states)
            elif isinstance(output_hidden_or_intermediate_states,
                            IntermediateTensors):
                hidden_or_intermediate_states = IntermediateTensors(
                    tensors={
                        key: weak_ref_tensor(value)
                        for key, value in
                        output_hidden_or_intermediate_states.tensors.items()
                    })

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_states` is deleted
            # make sure `output_hidden_or_intermediate_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()
@@ -1837,7 +1824,6 @@ def capture(
            }
        else:
            self.output_buffers = hidden_or_intermediate_states
        return hidden_or_intermediate_states

    def forward(
        self,
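Note (not part of the diff): with this change the capture path no longer pre-allocates per-virtual-engine output buffers and copies into them; it captures the model's own output and keeps only a weak reference to it, so the output storage stays in the CUDA graph's private memory pool and replay simply refreshes it in place. A minimal sketch of the pattern using the helper added in vllm/utils.py (warm-up runs and vLLM's attention state handling are omitted):

import torch
from vllm.utils import weak_ref_tensor

static_input = torch.zeros(8, 16, device="cuda")
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = torch.relu(static_input)

# Keep only a non-owning view of the captured output; the real buffer is
# owned by the graph's memory pool, not by this Python reference.
output_ref = weak_ref_tensor(static_output)
del static_output

static_input.fill_(3.0)
graph.replay()            # writes the new result into the same storage
print(output_ref[0, :4])  # reads the replayed output through the weak view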
