@@ -392,8 +392,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         std::vector<int64_t> output_shape = {num_rows, unpadded_hidden_size_val};
         auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
         kernels::MoeMinLatencyParams min_latency_params{};
@@ -553,8 +553,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
         min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
 
@@ -709,6 +709,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
     // e.g. 16 nvfp4 elements are packed into a single int64 element
     int64_t mInnerDimMultiplier;
     char* mProfileWorkspace = nullptr;
+    WorkspaceInfo workspace_info;
 
     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4GroupScaling = false;
@@ -757,9 +758,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
     }
 
-    WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+    WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int num_experts, int experts_per_token, ActivationType activation_type,
-        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
+        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
     {
         size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
             experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
@@ -768,15 +769,29 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
 
-        size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
 
-        WorkspaceInfo info{};
-        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
-        info.src_to_dest_map
-            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
+        bool is_capturing = tensorrt_llm::common::isCapturing(stream);
+        // Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
+        if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
+        {
+            if (is_capturing)
+            {
+                TLLM_LOG_DEBUG(
+                    "Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
+            }
+            else
+            {
+                TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
+                    workspace_info.workspace.numel(), total_workspace_size);
+            }
+            workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        }
+        workspace_info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);
 
-        return info;
+        return workspace_info;
     }
 
     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
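
The diff above caches the workspace tensor as a member of the runner and only reallocates it when the required size grows, while always taking a fresh allocation during CUDA graph capture so a replayed graph never references a buffer that a later, larger request would have replaced. Below is a minimal standalone sketch of that reuse-or-grow pattern; `WorkspaceCache` and `isCapturingSketch` are illustrative names, and the capture check is only assumed to be a thin wrapper over `cudaStreamIsCapturing` (the actual `tensorrt_llm::common::isCapturing` helper may differ).

```cpp
#include <cuda_runtime.h>
#include <torch/torch.h>

// Assumed behavior of tensorrt_llm::common::isCapturing: report whether the
// given stream is currently recording a CUDA graph.
inline bool isCapturingSketch(cudaStream_t stream)
{
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    cudaStreamIsCapturing(stream, &status);
    return status == cudaStreamCaptureStatusActive;
}

// Illustrative grow-only cache, mirroring the `workspace_info` member above.
struct WorkspaceCache
{
    torch::Tensor workspace; // persistent int8 CUDA buffer, reused across calls

    int8_t* get(int64_t required_bytes, cudaStream_t stream)
    {
        bool const capturing = isCapturingSketch(stream);
        // Reallocate when the buffer is missing or too small. During graph
        // capture, always allocate so the captured graph never aliases a
        // buffer that a later, larger request might replace and free.
        if (capturing || !workspace.defined() || workspace.numel() < required_bytes)
        {
            workspace = torch::empty({required_bytes},
                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
        }
        return static_cast<int8_t*>(workspace.data_ptr());
    }
};
```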