
Commit a6370fd

[https://nvbugs/5481434][feat] cherry-pick fix to reuse pytorch memory segments occupied by cudagraph (NVIDIA#7747)
Signed-off-by: Hui Gao <[email protected]>
1 parent fc4e6d3 commit a6370fd
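
The change in a nutshell: workspace tensors requested while a CUDA graph is being captured are served from the graph's private memory pool and stay reserved for every replay, so caching one sufficiently large buffer per workspace name keeps that pool from growing on each capture. A minimal sketch of the PyTorch behaviour the fix builds on (the tensor names and sizes below are illustrative, not taken from the commit):

import torch

# Tensors created while a stream is capturing come from the CUDA graph's
# private memory pool and stay reserved for as long as the graph is alive.
static_in = torch.zeros(8, 1024, device='cuda')

# Brief warm-up so lazy CUDA initialisation happens before capture.
for _ in range(3):
    _ = static_in * 2.0
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    assert torch.cuda.is_current_stream_capturing()
    static_out = static_in * 2.0  # allocated inside the capture pool

static_in.copy_(torch.randn(8, 1024, device='cuda'))
g.replay()  # replays the captured work, reusing the same memory segments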

File tree

1 file changed: +88 -9 lines changed


tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 88 additions & 9 deletions
@@ -1,3 +1,4 @@
+import math
 from typing import Dict, List, Optional, Union
 
 import torch
@@ -365,6 +366,10 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
     3. moe_finalize_scale_op: finalize the scale of the output tensor.
     """
 
+    # To reuse pytorch memory segments allocated during graph capture.
+    allocated_buffer_in_graph_pool: dict[str, list[torch.Tensor]] = {}
+    allocated_buffer_in_runtime: dict[str, torch.Tensor] = {}
+
     def __init__(
         self,
         *,
@@ -410,28 +415,102 @@ def __init__(
         )
 
     def get_workspace(self, m_max: int, group_size: int):
+
+        def select_buffer_with_more_elements(
+                graph_buffer: Optional[torch.Tensor],
+                runtime_buffer: Optional[torch.Tensor],
+                is_capturing: bool = False
+        ) -> tuple[Optional[torch.Tensor], bool]:
+            if is_capturing and graph_buffer is not None:
+                return graph_buffer, True
+
+            if is_capturing == False and runtime_buffer is not None:
+                return runtime_buffer, False
+
+            if graph_buffer is None:
+                return runtime_buffer, False
+
+            if runtime_buffer is None:
+                return graph_buffer, True
+
+        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
+                      cache_name: str) -> torch.Tensor:
+            capture_graph = torch.cuda.is_current_stream_capturing()
+            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
+                numel_like = math.prod(tensor_shape)
+                runtime_buffer = None
+                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
+                    buffer = DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                        cache_name]
+                    numel_buffer = buffer.numel()
+                    runtime_buffer = buffer if numel_buffer >= numel_like else None
+
+                graph_buffer = None
+                # Safely get the list of candidates. Defaults to an empty list if key is missing.
+                candidate_buffers = DeepGemmFusedMoE.allocated_buffer_in_graph_pool.get(
+                    cache_name, [])
+                for buffer in candidate_buffers:
+                    numel_buffer = buffer.numel()
+                    # buffer just needs to be large enough.
+                    if numel_buffer >= numel_like:
+                        graph_buffer = buffer
+                        break
+
+                if capture_graph and graph_buffer is not None:
+                    return graph_buffer[0:numel_like].view(tensor_shape)
+                else:
+                    buffer, use_graph = select_buffer_with_more_elements(
+                        graph_buffer,
+                        runtime_buffer,
+                        is_capturing=capture_graph)
+                    if buffer is not None:
+                        if not use_graph and capture_graph:
+                            # move the buffer into graph buffers since it's running in graph capturing mode.
+                            DeepGemmFusedMoE.allocated_buffer_in_runtime.pop(
+                                cache_name, None)
+                            DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
+                                cache_name, []).append(buffer)
+
+                        return buffer[0:numel_like].view(tensor_shape)
+
+                # Reach here, no buffer is found. Then, we will use a new buffer to replace the small one. Release the memory first.
+                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
+                    del DeepGemmFusedMoE.allocated_buffer_in_runtime[cache_name]
+
+            # If we get here, no suitable buffer was found in the cache. Create a new one.
+            new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
+            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
+                if capture_graph:
+                    DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
+                        cache_name, []).append(new_buffer)
+                else:
+                    DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                        cache_name] = new_buffer
+            return new_buffer
+
         hidden_size = self.hidden_size
         intermediate_size = self.intermediate_size_per_partition
         num_experts = self.expert_size_per_partition
 
         # create workspace
         fp8_dim = max(hidden_size, intermediate_size)
-        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
-                                  dtype=torch.float8_e4m3fn,
-                                  device='cuda')
-        workspace_1 = torch.empty(
-            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
+        workspace_0 = get_empty((num_experts * m_max * fp8_dim, ),
+                                dtype=torch.float8_e4m3fn,
+                                cache_name='workspace_0')
+        workspace_1 = get_empty(
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size), ),
             dtype=torch.bfloat16,
-            device='cuda')
+            cache_name='workspace_1')
 
         # create workspace for scaling factors
         m_padded = fp8_utils.align(m_max, 4)
         scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
         scale_k_padded = fp8_utils.align(scale_k, 4)
-        workspace_sf = torch.empty(
-            (num_experts * (scale_k_padded // 4) * m_padded),
+
+        workspace_sf = get_empty(
+            (num_experts * (scale_k_padded // 4) * m_padded, ),
             dtype=torch.int32,
-            device='cuda')
+            cache_name='workspace_sf')
 
         workspace = {
             "workspace_0": workspace_0,

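For readers skimming the diff, here is a simplified, self-contained sketch of the caching policy that the new get_empty helper implements. The names _graph_pool, _runtime_pool and cached_empty are illustrative only, and the promotion of an eager-mode buffer into the graph pool during capture is omitted:

import math

import torch

# One pool for buffers created under CUDA graph capture, plus one slot per
# name for eager-mode buffers; a cached buffer is reused if it is big enough.
_graph_pool: dict[str, list[torch.Tensor]] = {}
_runtime_pool: dict[str, torch.Tensor] = {}


def cached_empty(name: str, shape: tuple[int, ...],
                 dtype: torch.dtype) -> torch.Tensor:
    capturing = torch.cuda.is_current_stream_capturing()
    numel = math.prod(shape)

    # Reuse a buffer that matches the current mode and is large enough;
    # a larger cached buffer is simply viewed at the requested shape.
    if capturing:
        candidates = _graph_pool.get(name, [])
    else:
        candidates = [_runtime_pool[name]] if name in _runtime_pool else []
    for buf in candidates:
        if buf.numel() >= numel:
            return buf.flatten()[:numel].view(shape)

    # Nothing suitable is cached: allocate a new buffer and remember it.
    buf = torch.zeros(shape, device='cuda', dtype=dtype)
    if capturing:
        _graph_pool.setdefault(name, []).append(buf)
    else:
        _runtime_pool[name] = buf
    return buf

In the commit itself the three workspaces are requested under the names 'workspace_0', 'workspace_1' and 'workspace_sf', so capture-time calls keep hitting the buffers that live in the graph pool while eager calls reuse a single runtime buffer per name.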