
Commit ef22c6c

shen-shanshan authored and HwwwwwwwH committed
[Platform] move current_memory_usage() into platform (vllm-project#11369)
Signed-off-by: Shanshan Shen <[email protected]>
Signed-off-by: hzh <[email protected]>
1 parent 1bba3f6 · commit ef22c6c

File tree

5 files changed: +31 −7 lines changed


vllm/platforms/cuda.py (+7)
@@ -143,6 +143,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1) -> str:

vllm/platforms/interface.py (+9)
@@ -277,6 +277,15 @@ def is_pin_memory_available(cls) -> bool:
             return False
         return True
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        """
+        Return the memory usage in bytes.
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_punica_wrapper(cls) -> str:
         """

vllm/platforms/rocm.py (+7)
@@ -157,3 +157,10 @@ def verify_quantization(cls, quant: str) -> None:
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)

vllm/platforms/xpu.py (+7)
@@ -94,3 +94,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on XPU.")
         return False
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.xpu.reset_peak_memory_stats(device)
+        return torch.xpu.max_memory_allocated(device)

vllm/utils.py (+1 −7)
@@ -710,13 +710,7 @@ def __init__(self, device: Optional[torch.types.Device] = None):
     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
         from vllm.platforms import current_platform
-        if current_platform.is_cuda_alike():
-            torch.cuda.reset_peak_memory_stats(self.device)
-            mem = torch.cuda.max_memory_allocated(self.device)
-        elif current_platform.is_xpu():
-            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
-            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
-        return mem
+        return current_platform.get_current_memory_usage(self.device)
 
     def __enter__(self):
         self.initial_memory = self.current_memory_usage()
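With the dispatch above, the memory profiler no longer branches on the device type; it delegates to current_platform. A short usage sketch of the new hook (it assumes a CUDA-like device is available; the device string and the GiB conversion are illustrative, not part of this commit):

import torch

from vllm.platforms import current_platform

# Peak memory allocated on the given device, in bytes. Under the hood this
# resets the backend's peak-memory counter and reads its high-water mark.
used_bytes = current_platform.get_current_memory_usage(torch.device("cuda:0"))
print(f"peak device memory: {used_bytes / 2**30:.2f} GiB")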
