+import math
 from typing import Dict, List, Optional, Union

 import torch
@@ -365,6 +366,10 @@ class DeepGemmFusedMoE(CutlassFusedMoE): |
     3. moe_finalize_scale_op: finalize the scale of the output tensor.
     """

+    # To reuse PyTorch memory segments allocated during graph capture.
+    allocated_buffer_in_graph_pool: dict[str, list[torch.Tensor]] = {}
+    allocated_buffer_in_runtime: dict[str, torch.Tensor] = {}
+
     def __init__(
         self,
         *,
@@ -410,28 +415,102 @@ def __init__( |
         )

     def get_workspace(self, m_max: int, group_size: int):
+
+        def select_buffer_with_more_elements(
+                graph_buffer: Optional[torch.Tensor],
+                runtime_buffer: Optional[torch.Tensor],
+                is_capturing: bool = False
+        ) -> tuple[Optional[torch.Tensor], bool]:
+            if is_capturing and graph_buffer is not None:
+                return graph_buffer, True
+
+            if not is_capturing and runtime_buffer is not None:
+                return runtime_buffer, False
+
+            if graph_buffer is None:
+                return runtime_buffer, False
+
+            if runtime_buffer is None:
+                return graph_buffer, True
+
+        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
+                      cache_name: str) -> torch.Tensor:
+            capture_graph = torch.cuda.is_current_stream_capturing()
+            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
+                numel_like = math.prod(tensor_shape)
+                runtime_buffer = None
+                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
+                    buffer = DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                        cache_name]
+                    numel_buffer = buffer.numel()
+                    runtime_buffer = buffer if numel_buffer >= numel_like else None
+
+                graph_buffer = None
+                # Safely get the list of candidates. Defaults to an empty list if the key is missing.
+                candidate_buffers = DeepGemmFusedMoE.allocated_buffer_in_graph_pool.get(
+                    cache_name, [])
+                for buffer in candidate_buffers:
+                    numel_buffer = buffer.numel()
+                    # The buffer only needs to be large enough.
+                    if numel_buffer >= numel_like:
+                        graph_buffer = buffer
+                        break
+
+                if capture_graph and graph_buffer is not None:
+                    return graph_buffer[0:numel_like].view(tensor_shape)
+                else:
+                    buffer, use_graph = select_buffer_with_more_elements(
+                        graph_buffer,
+                        runtime_buffer,
+                        is_capturing=capture_graph)
+                    if buffer is not None:
+                        if not use_graph and capture_graph:
+                            # Move the buffer into the graph pool since we are in graph capturing mode.
+                            DeepGemmFusedMoE.allocated_buffer_in_runtime.pop(
+                                cache_name, None)
+                            DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
+                                cache_name, []).append(buffer)
+
+                        return buffer[0:numel_like].view(tensor_shape)
+
+                # No cached buffer is large enough; a new buffer will replace the small one below, so release its memory first.
+                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
+                    del DeepGemmFusedMoE.allocated_buffer_in_runtime[cache_name]
+
+            # If we get here, no suitable buffer was found in the cache. Create a new one.
+            new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
+            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
+                if capture_graph:
+                    DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
+                        cache_name, []).append(new_buffer)
+                else:
+                    DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                        cache_name] = new_buffer
+            return new_buffer
+
         hidden_size = self.hidden_size
         intermediate_size = self.intermediate_size_per_partition
         num_experts = self.expert_size_per_partition

         # create workspace
         fp8_dim = max(hidden_size, intermediate_size)
-        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
-                                  dtype=torch.float8_e4m3fn,
-                                  device='cuda')
-        workspace_1 = torch.empty(
-            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
+        workspace_0 = get_empty((num_experts * m_max * fp8_dim, ),
+                                dtype=torch.float8_e4m3fn,
+                                cache_name='workspace_0')
+        workspace_1 = get_empty(
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size), ),
             dtype=torch.bfloat16,
-            device='cuda')
+            cache_name='workspace_1')

         # create workspace for scaling factors
         m_padded = fp8_utils.align(m_max, 4)
         scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
         scale_k_padded = fp8_utils.align(scale_k, 4)
-        workspace_sf = torch.empty(
-            (num_experts * (scale_k_padded // 4) * m_padded),
+
+        workspace_sf = get_empty(
+            (num_experts * (scale_k_padded // 4) * m_padded, ),
             dtype=torch.int32,
-            device='cuda')
+            cache_name='workspace_sf')

         workspace = {
             "workspace_0": workspace_0,
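The runtime-path caching that `get_empty` implements can be summarized by the standalone sketch below. It is a simplified illustration, not part of the change: the helper name `cached_empty` and the module-level dict are hypothetical, the per-name pool for buffers allocated during CUDA graph capture is omitted, and a CUDA device is assumed.

```python
import math

import torch

# Hypothetical runtime cache: one flat backing buffer per workspace name.
_runtime_cache: dict[str, torch.Tensor] = {}


def cached_empty(shape: list[int], dtype: torch.dtype,
                 cache_name: str) -> torch.Tensor:
    """Return a view of a cached flat buffer, growing it only when needed."""
    numel = math.prod(shape)
    buffer = _runtime_cache.get(cache_name)
    if buffer is None or buffer.numel() < numel or buffer.dtype != dtype:
        # The cached buffer is missing or too small: allocate a replacement.
        buffer = torch.zeros(numel, dtype=dtype, device='cuda')
        _runtime_cache[cache_name] = buffer
    # Hand out a reshaped view over the first `numel` elements, so repeated
    # calls with the same or a smaller size reuse the same storage.
    return buffer[:numel].view(shape)
```

Repeated calls with the same `cache_name` and non-growing sizes reuse one flat allocation, which is the effect the class-level `allocated_buffer_in_*` caches above aim for across `get_workspace` calls.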