
[core] improve cpu offloading implementation #10609

Draft
wants to merge 19 commits into base: main
18 changes: 16 additions & 2 deletions tests/basic_correctness/test_cpu_offload.py
@@ -2,5 +2,19 @@


def test_cpu_offload():
compare_two_settings("meta-llama/Llama-3.2-1B", [],
["--cpu-offload-gb", "1"])
compare_two_settings("meta-llama/Llama-3.1-8B", [],
["--cpu-offload-gb", "2"])


#
#
# def test_cpu_offload_gptq():
# # Test GPTQ Marlin
# compare_two_settings("Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", [],
# ["--cpu-offload-gb", "1"],
# max_wait_seconds=480)
# # Test GPTQ
# compare_two_settings("Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
# ["--quantization", "gptq"],
# ["--quantization", "gptq", "--cpu-offload-gb", "1"],
# max_wait_seconds=480)
155 changes: 93 additions & 62 deletions vllm/model_executor/models/utils.py
@@ -5,7 +5,7 @@

import torch
import torch.nn as nn
-from torch.func import functional_call
+from torch.utils._python_dispatch import TorchDispatchMode
from transformers import PretrainedConfig

import vllm.envs as envs
@@ -238,6 +238,90 @@ def load_weights(
        return autoloaded_weights


class OffloadedTensorMode(TorchDispatchMode):

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        tensor = func(*args, **kwargs)

        if (func is torch.ops.aten.empty.memory_format
                and tensor.device != "cpu"):

Review comment (Member Author): maybe use torch.device("cpu") instead of "cpu"?

Reply: hmm, tensor.device != "cpu" should generally be ok
            global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
            if _CPU_OFFLOAD_BYTES < _CPU_OFFLOAD_MAX_BYTES:
                _CPU_OFFLOAD_BYTES += tensor.numel() * tensor.element_size()
                return OffloadedTensor(tensor)
        return tensor


class OffloadedTensor(torch.Tensor):

Review comment on lines +256 to +257: We do generally have support for subclasses that implement both torch_function and torch_dispatch, although if you only need torch_dispatch then I agree that you probably want to disable torch_function as linked above.

Let me know if you have any other questions / would like to chat more about the subclass work you're doing!
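
A minimal sketch of the pattern the reviewer refers to (not from this PR): a tensor subclass that only needs __torch_dispatch__ can opt out of the __torch_function__ protocol with a single class attribute. If that is the intent here, OffloadedTensor could carry the same attribute.

import torch


class DispatchOnlyTensor(torch.Tensor):
    # Opt out of __torch_function__ handling so that only __torch_dispatch__
    # sees operations on this subclass.
    __torch_function__ = torch._C._disabled_torch_function_impl
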

    def __init__(self, elem):
        super().__init__()

        if elem.device == torch.device("cpu"):
            # no need to offload the tensor
            self.offloaded_tensor = elem
            return

        # use pin_memory if possible, which helps cudagraph capture speed
        pin_memory = is_pin_memory_available()
        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=elem.size(),
                                       stride=elem.stride(),
                                       dtype=elem.dtype,
                                       layout=elem.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(elem)
        self.offloaded_tensor = cpu_data

    def load(self):
        return self.offloaded_tensor.to(self.device, non_blocking=True)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def unwrap(x):
            return x.offloaded_tensor if isinstance(x, cls) else x

        def load(x):
            return x.load() if isinstance(x, cls) else x

        def tree_map(func, x):
            if isinstance(x, (list, tuple)):
                return type(x)(tree_map(func, y) for y in x)
            if isinstance(x, dict):
                return {k: tree_map(func, v) for k, v in x.items()}
            return func(x)

        if str(func) == "aten.detach.default":
            # `nn.Parameter(data)` will call `data.detach()`
            # and assert the returned data type is the same
            # as the original data
            return args[0]
        if str(func) in ["aten.copy_.default"]:
            # inplace or view operation on the offloaded tensor
            # TODO: support more inplace operations if needed
            return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        if str(func) == "aten.uniform_.default":
            res = func(*tree_map(load, args), **tree_map(load, kwargs))
            args[0].offloaded_tensor = res.cpu()
            return args[0]
        if func._schema.is_mutable:
            # the behavior of mutable ops for offloaded tensor
            # is not well-defined and needs case-by-case discussion
            raise ValueError(
                f"Unrecognized mutable operation {func}"
                " on an offloaded tensor. Please open an issue to discuss"
                " the support for this operation.")

        # for the rest of the operations, we will load the offloaded tensor
        # on the fly and perform the operation on the device
        return func(*tree_map(load, args), **tree_map(load, kwargs))
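
A hedged usage sketch (not part of the diff) of how the new mode and subclass are meant to interact, assuming a CUDA device is available and using set_cpu_offload_max_bytes defined later in this file: allocations made while the mode is active are wrapped in OffloadedTensor, whose backing storage lives in (pinned) CPU memory, and ordinary ops load them back onto the device on the fly.

set_cpu_offload_max_bytes(1 << 30)  # allow up to ~1 GiB of offloaded weights

# Parameter allocations inside the mode go through aten.empty.memory_format
# and are wrapped in OffloadedTensor by OffloadedTensorMode.
with OffloadedTensorMode():
    layer = torch.nn.Linear(4096, 4096, device="cuda")

# Using the offloaded parameters triggers __torch_dispatch__ above, which
# loads the CPU copies onto the device before running the op.
x = torch.randn(1, 4096, device="cuda")
y = layer(x)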


def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
@@ -477,62 +561,6 @@ def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    _CPU_OFFLOAD_MAX_BYTES = max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    device = next(module.parameters()).device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module,
                                     device_state,
                                     args=args,
                                     kwargs=kwargs)
            module.forward = forward
            return output

        module.forward = forward

    return module


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
@@ -546,11 +574,14 @@
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            get_pp_group().rank_in_group,
                                            get_pp_group().world_size)
-    modules = torch.nn.ModuleList(
-        [PPMissingLayer() for _ in range(start_layer)] + [
-            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
-            for idx in range(start_layer, end_layer)
-        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+    # with maybe_offload_to_cpu():
+    with OffloadedTensorMode():
+        modules = torch.nn.ModuleList(
+            [PPMissingLayer() for _ in range(start_layer)] + [
+                layer_fn(prefix=f"{prefix}.{idx}")
+                for idx in range(start_layer, end_layer)
+            ] +
+            [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules

