
Commit 60bdd43

shen-shanshan authored and jikunshang committed
[Platform] move current_memory_usage() into platform (vllm-project#11369)
Signed-off-by: Shanshan Shen <[email protected]>
1 parent ae57544 commit 60bdd43

File tree

5 files changed (+31, −7)

vllm/platforms/cuda.py (+7)
```diff
@@ -143,6 +143,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1) -> str:
```
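A note on the idiom used here: `torch.cuda.reset_peak_memory_stats(device)` clamps the peak counter down to the bytes currently allocated, so the `torch.cuda.max_memory_allocated(device)` call that follows reports live usage rather than an old high-water mark. A minimal standalone sketch of the same pattern (illustrative only, not part of this commit):

```python
import torch

def current_cuda_bytes(device=None) -> float:
    # Resetting the peak statistic sets it to the current allocation,
    # so max_memory_allocated() read immediately afterwards means "now".
    torch.cuda.reset_peak_memory_stats(device)
    return torch.cuda.max_memory_allocated(device)

if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device="cuda")  # ~4 MiB of fp32
    print(current_cuda_bytes())  # includes x's allocation
    del x
```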

vllm/platforms/interface.py (+9)
```diff
@@ -278,6 +278,15 @@ def is_pin_memory_available(cls) -> bool:
             return False
         return True
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        """
+        Return the memory usage in bytes.
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_punica_wrapper(cls) -> str:
         """
```

vllm/platforms/rocm.py (+7)
```diff
@@ -157,3 +157,10 @@ def verify_quantization(cls, quant: str) -> None:
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
```
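The ROCm override being identical to the CUDA one is intentional rather than a copy-paste slip: ROCm builds of PyTorch expose HIP devices through the `torch.cuda` namespace, so the same `reset_peak_memory_stats`/`max_memory_allocated` calls work unchanged.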

vllm/platforms/xpu.py (+7)
```diff
@@ -108,3 +108,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on XPU.")
         return False
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.xpu.reset_peak_memory_stats(device)
+        return torch.xpu.max_memory_allocated(device)
```

vllm/utils.py (+1, −7)
```diff
@@ -710,13 +710,7 @@ def __init__(self, device: Optional[torch.types.Device] = None):
     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
         from vllm.platforms import current_platform
-        if current_platform.is_cuda_alike():
-            torch.cuda.reset_peak_memory_stats(self.device)
-            mem = torch.cuda.max_memory_allocated(self.device)
-        elif current_platform.is_xpu():
-            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
-            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
-        return mem
+        return current_platform.get_current_memory_usage(self.device)
 
     def __enter__(self):
         self.initial_memory = self.current_memory_usage()
```
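After the change, the profiler context manager in `vllm/utils.py` is platform-agnostic end to end. A usage sketch, assuming the enclosing class is `DeviceMemoryProfiler` with a `consumed_memory` attribute set on exit (neither name appears in the excerpt above, so treat both as assumptions from context):

```python
import torch

from vllm.utils import DeviceMemoryProfiler  # class name assumed, see note

# Measure device memory consumed inside the block; the per-platform
# branching now lives behind current_platform, not in the profiler.
with DeviceMemoryProfiler() as profiler:
    weights = torch.empty(4096, 4096, device="cuda")  # any allocation

print(f"consumed: {profiler.consumed_memory / 1024**2:.1f} MiB")  # assumed attr
```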
