[core] improve cpu offloading implementation #10609
base: main
Changes from 18 commits
c00a93a
792650f
c568a55
a874918
6e4f195
5398836
17be34f
2dead95
2e5f95c
f2dd4a8
b2c195c
ad74693
a737dbb
608bfaf
89a2190
08d12f5
46c0d98
f2e05ce
3b14773
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
-from torch.func import functional_call
+from torch.utils._python_dispatch import TorchDispatchMode
 from transformers import PretrainedConfig

 import vllm.envs as envs

@@ -238,6 +238,90 @@ def load_weights(
         return autoloaded_weights


+class OffloadedTensorMode(TorchDispatchMode):
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+        tensor = func(*args, **kwargs)
+
+        if (func is torch.ops.aten.empty.memory_format
+                and tensor.device != "cpu"):
+            global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
+            if _CPU_OFFLOAD_BYTES < _CPU_OFFLOAD_MAX_BYTES:
+                _CPU_OFFLOAD_BYTES += tensor.numel() * tensor.element_size()
+                return OffloadedTensor(tensor)
+        return tensor
+
+
+class OffloadedTensor(torch.Tensor):
+
Comment on lines +256 to +257:

Reviewer: need to have […]; we might also need
https://github.com/albanD/subclass_zoo/blob/ec47458346c2a1cfcd5e676926a4bbc6709ff62e/base_tensor.py#L25

Reply: We do generally have support for subclasses that implement both
__torch_function__ and __torch_dispatch__, although if you only need
__torch_dispatch__ then I agree that you probably want to disable
__torch_function__ as linked above. Let me know if you have any other
questions / would like to chat more about the subclass work you're doing!
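If memory serves, the linked base_tensor.py line pins __torch_function__ to
PyTorch's disabled implementation. A minimal sketch of applying that
suggestion to this subclass (an assumption on my part, not part of the PR as
of these commits):

    class OffloadedTensor(torch.Tensor):
        # opt out of the __torch_function__ protocol so that only
        # __torch_dispatch__ below intercepts operations on the subclass
        __torch_function__ = torch._C._disabled_torch_function_impl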
+    def __init__(self, elem):
+        super().__init__()
+
+        if elem.device == torch.device("cpu"):
+            # no need to offload the tensor
+            self.offloaded_tensor = elem
+            return
+
+        # use pin_memory if possible, which helps cudagraph capture speed
+        pin_memory = is_pin_memory_available()
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty_strided(size=elem.size(),
+                                       stride=elem.stride(),
+                                       dtype=elem.dtype,
+                                       layout=elem.layout,
+                                       device='cpu',
+                                       pin_memory=pin_memory)
+        cpu_data.copy_(elem)
+        self.offloaded_tensor = cpu_data
+
+    def load(self):
+        return self.offloaded_tensor.to(self.device, non_blocking=True)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        def unwrap(x):
+            return x.offloaded_tensor if isinstance(x, cls) else x
+
+        def load(x):
+            return x.load() if isinstance(x, cls) else x
+
+        def tree_map(func, x):
+            if isinstance(x, (list, tuple)):
+                return type(x)(tree_map(func, y) for y in x)
+            if isinstance(x, dict):
+                return {k: tree_map(func, v) for k, v in x.items()}
+            return func(x)
+
+        if str(func) == "aten.detach.default":
+            # `nn.Parameter(data)` will call `data.detach()`
+            # and assert the returned data type is the same
+            # as the original data
+            return args[0]
+        if str(func) in ["aten.copy_.default"]:
+            # inplace or view operation on the offloaded tensor
+            # TODO: support more inplace operations if needed
+            return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
+        if str(func) == "aten.uniform_.default":
+            res = func(*tree_map(load, args), **tree_map(load, kwargs))
+            args[0].offloaded_tensor = res.cpu()
+            return args[0]
+        if func._schema.is_mutable:
+            # the behavior of mutable ops for offloaded tensor
+            # is not well-defined and needs case-by-case discussion
+            raise ValueError(
+                f"Unrecognized mutable operation {func}"
+                " on an offloaded tensor. Please open an issue to discuss"
+                " the support for this operation.")
+
+        # for the rest of the operations, we will load the offloaded tensor
+        # on the fly and perform the operation on the device
+        return func(*tree_map(load, args), **tree_map(load, kwargs))
+
+
 def init_vllm_registered_model(
     vllm_config: VllmConfig,
     *,
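To make the new mechanism concrete: a TorchDispatchMode observes every
aten-level call issued while it is active, so tensor allocations can be
rewritten at creation time. A minimal standalone sketch of that interception
pattern (the mode name and print are illustrative, not from this PR):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class LogEmptyMode(TorchDispatchMode):
        # report each aten.empty.memory_format call, the same hook
        # OffloadedTensorMode uses above to wrap freshly created tensors
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = {} if kwargs is None else kwargs
            out = func(*args, **kwargs)
            if func is torch.ops.aten.empty.memory_format:
                print(f"allocated {tuple(out.shape)} on {out.device}")
            return out

    with LogEmptyMode():
        torch.empty(4, 4)  # prints: allocated (4, 4) on cpu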
@@ -477,62 +561,6 @@ def set_cpu_offload_max_bytes(max_bytes: int) -> None:
     _CPU_OFFLOAD_MAX_BYTES = max_bytes


-def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
-    device = next(module.parameters()).device
-
-    if device == torch.device("cpu"):
-        return module
-
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-        return module
-
-    pin_memory = is_pin_memory_available()
-
-    # offload parameters to CPU
-    # use pin_memory if possible, which helps cudagraph capture speed
-    offloaded_parameters = False
-    for p in module.parameters():
-        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-            # we use per-parameter offloading
-            # one module might have some parameters offloaded and some not
-            break
-
-        # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty_strided(size=p.data.size(),
-                                       stride=p.data.stride(),
-                                       dtype=p.data.dtype,
-                                       layout=p.data.layout,
-                                       device='cpu',
-                                       pin_memory=pin_memory)
-        cpu_data.copy_(p.data)
-        p.data = cpu_data
-        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
-        offloaded_parameters = True
-
-    if offloaded_parameters:
-        original_forward = module.forward
-
-        def forward(*args, **kwargs):
-            module.forward = original_forward
-            device_state = {
-                # here we blindly call `to(device)`
-                # if the parameter is already on the device, it will be a no-op
-                k: v.to(device, non_blocking=True)
-                for k, v in module.state_dict().items()
-            }
-            output = functional_call(module,
-                                     device_state,
-                                     args=args,
-                                     kwargs=kwargs)
-            module.forward = forward
-            return output
-
-        module.forward = forward
-
-    return module
-
-
 def make_layers(
     num_hidden_layers: int,
     layer_fn: LayerFn,
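For context on the code removed above: the old path kept parameters in
(pinned) CPU memory and re-materialized them on the device for every forward
call through torch.func.functional_call. A standalone sketch of that pattern,
with an illustrative nn.Linear standing in for a vLLM layer and assuming a
CUDA device is available:

    import torch
    import torch.nn as nn
    from torch.func import functional_call

    layer = nn.Linear(16, 16)      # parameters stay on the CPU
    device = torch.device("cuda")  # assumed target device

    def offloaded_forward(x: torch.Tensor) -> torch.Tensor:
        # copy the CPU weights to the device for this call only
        device_state = {k: v.to(device, non_blocking=True)
                        for k, v in layer.state_dict().items()}
        # run the module functionally with the temporary device weights
        return functional_call(layer, device_state, (x,))

    out = offloaded_forward(torch.randn(2, 16, device=device))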
@@ -546,11 +574,14 @@ def make_layers(
     start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                             get_pp_group().rank_in_group,
                                             get_pp_group().world_size)
-    modules = torch.nn.ModuleList(
-        [PPMissingLayer() for _ in range(start_layer)] + [
-            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
-            for idx in range(start_layer, end_layer)
-        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+    # with maybe_offload_to_cpu():
+    with OffloadedTensorMode():
+        modules = torch.nn.ModuleList(
+            [PPMissingLayer() for _ in range(start_layer)] + [
+                layer_fn(prefix=f"{prefix}.{idx}")
+                for idx in range(start_layer, end_layer)
+            ] +
+            [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
     return start_layer, end_layer, modules
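A rough sketch of how the new path is driven, as far as I can tell from the
diff (the budget value and the nn.Linear layer are illustrative; assumes a
CUDA device):

    import torch.nn as nn

    set_cpu_offload_max_bytes(2 * 1024**3)  # e.g. a 2 GiB offload budget

    with OffloadedTensorMode():
        # non-CPU parameter allocations inside the mode are wrapped as
        # OffloadedTensor, with their data kept in pinned CPU memory,
        # until the budget above is exhausted
        layer = nn.Linear(4096, 4096, device="cuda")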
Comment (on the tensor.device != "cpu" check in OffloadedTensorMode):

Reviewer: maybe use torch.device("cpu") instead of "cpu"?

Reply: hmm, tensor.device != "cpu" should generally be ok
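For what it's worth, torch.device defines equality against strings, so the
check behaves as the reply suggests:

    import torch

    assert torch.device("cpu") == "cpu"     # device/str comparison is defined
    assert torch.device("cuda:0") != "cpu"  # any accelerator device differs
    # caveat: torch.device("cuda:0") != "cuda" too, since the index differs;
    # comparing against "cpu" sidesteps that pitfall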