Commit a4b7c0d

fix(eagle3): change pytorch_weights_path to speculative_model_dir

Merge commit a4b7c0d, 2 parents: 153bc70 + b04ba97

27 files changed: +1079 / −223 lines


cpp/tensorrt_llm/thop/allgatherOp.cpp

Lines changed: 41 additions & 50 deletions
@@ -55,70 +55,61 @@ class AllgatherOp
         return 0;
     }

-    torch::Tensor run(torch::Tensor input, torch::optional<torch::List<int64_t>> sizes)
+    std::vector<torch::Tensor> run_list(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes)
     {
         TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before used");
-        auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
-        auto type = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
-        std::vector<int64_t> outputShape = input.sizes().vec();
-        if (sizes.has_value())
-        {
-            outputShape[0] = std::accumulate(sizes.value().begin(), sizes.value().end(), 0, std::plus<>{});
-        }
-        else
-        {
-            outputShape[0] *= mGroup.size();
-        }
-        auto output = torch::empty(outputShape, input.options());
         bool use_nccl_allgather = !sizes.has_value()
             || std::all_of(sizes.value().begin(), sizes.value().end(),
                 [&sizes](int64_t size) { return size == sizes.value()[0]; });
-        if (use_nccl_allgather)
-        {
-            NCCLCHECK_THROW(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(),
-                (*getDtypeMap())[type], *mNcclComm, stream));
-        }
-        else
-        {
-            size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
-            int64_t split_offset = 0;
-            ncclGroupStart();
-            for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
-            {
-                auto split_size = sizes.value()[root];
-                NCCLCHECK_THROW(ncclBroadcast(input.data_ptr(),
-                    output.index({torch::indexing::Slice(split_offset, torch::indexing::None)}).mutable_data_ptr(),
-                    numel_base * split_size, (*getDtypeMap())[type], root, *mNcclComm, stream));
-                split_offset += split_size;
-            }
-            ncclGroupEnd();
-        }
-        return output;
-    }
-
-    std::vector<torch::Tensor> run_list(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes)
-    {
+        int64_t sum_sizes
+            = sizes.has_value() ? std::accumulate(sizes.value().begin(), sizes.value().end(), 0, std::plus<>{}) : 0;
         std::vector<torch::Tensor> output_list;
         output_list.reserve(input_list.size());
-        bool use_nccl_allgather = !sizes.has_value()
-            || std::all_of(sizes.value().begin(), sizes.value().end(),
-                [&sizes](int64_t size) { return size == sizes.value()[0]; });
-        if (use_nccl_allgather)
-        {
-            ncclGroupStart();
-        }
+        ncclGroupStart();
         for (auto const& input : input_list)
         {
-            auto output = run(input, sizes);
+            auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
+            auto type = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
+            std::vector<int64_t> outputShape = input.sizes().vec();
+            if (sizes.has_value())
+            {
+                outputShape[0] = sum_sizes;
+            }
+            else
+            {
+                outputShape[0] *= mGroup.size();
+            }
+            auto output = torch::empty(outputShape, input.options());
+            if (use_nccl_allgather)
+            {
+                ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(), (*getDtypeMap())[type],
+                    *mNcclComm, stream);
+            }
+            else
+            {
+                size_t numel_base
+                    = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
+                int64_t split_offset = 0;
+                for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
+                {
+                    auto split_size = sizes.value()[root];
+                    ncclBroadcast(input.data_ptr(),
+                        output.index({torch::indexing::Slice(split_offset, torch::indexing::None)}).mutable_data_ptr(),
+                        numel_base * split_size, (*getDtypeMap())[type], root, *mNcclComm, stream);
+                    split_offset += split_size;
+                }
+            }
             output_list.push_back(output);
         }
-        if (use_nccl_allgather)
-        {
-            ncclGroupEnd();
-        }
+        NCCLCHECK_THROW(ncclGroupEnd());
        return output_list;
     }

+    torch::Tensor run(torch::Tensor input, torch::optional<torch::List<int64_t>> sizes)
+    {
+        return run_list({input}, sizes)[0];
+    }
+
 private:
     std::set<int> mGroup;
     std::shared_ptr<ncclComm_t> mNcclComm;
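The core of this refactor is NCCL group semantics: every collective issued between ncclGroupStart() and ncclGroupEnd() is recorded and launched as one batch, so run_list can gather a whole list of tensors in a single grouped launch and the single-tensor run() can simply delegate to run_list({input}, sizes)[0]. A minimal standalone sketch of that pattern (illustrative only, not the TensorRT-LLM code; buffer allocation, communicator, and stream setup are assumed to happen elsewhere):

// Sketch only: gather a list of device buffers with one grouped NCCL launch.
// recvBufs[i] is assumed to hold nranks * counts[i] elements of the given dtype.
#include <nccl.h>
#include <cuda_runtime.h>
#include <vector>

void allgatherMany(std::vector<void const*> const& sendBufs, std::vector<void*> const& recvBufs,
    std::vector<size_t> const& counts, ncclDataType_t dtype, ncclComm_t comm, cudaStream_t stream)
{
    ncclGroupStart(); // open one batched group; nothing is launched yet
    for (size_t i = 0; i < sendBufs.size(); ++i)
    {
        // each call is only recorded while the group is open
        ncclAllGather(sendBufs[i], recvBufs[i], counts[i], dtype, comm, stream);
    }
    ncclGroupEnd(); // all of the gathers above are launched together here
}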

cpp/tensorrt_llm/thop/reducescatterOp.cpp

Lines changed: 41 additions & 51 deletions
@@ -55,79 +55,69 @@ class ReducescatterOp
         return 0;
     }

-    torch::Tensor run(torch::Tensor const& input, torch::optional<torch::List<int64_t>> sizes)
+    std::vector<torch::Tensor> run_list(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes)
     {
         TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before used");
-        auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
-        auto type = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
-        std::vector<int64_t> outputShape = input.sizes().vec();
+        bool use_nccl_reducescatter = !sizes.has_value()
+            || std::all_of(sizes.value().begin(), sizes.value().end(),
+                [&sizes](int64_t size) { return size == sizes.value()[0]; });
+        int groupRank = 0;
         if (sizes.has_value())
         {
             auto rank = COMM_SESSION.getRank();
-            int groupRank = 0;
             for (auto const& currentRank : mGroup)
             {
                 if (rank == currentRank)
                     break;
                 ++groupRank;
             }
             TLLM_CHECK(static_cast<size_t>(groupRank) < mGroup.size());
-            outputShape[0] = sizes.value()[groupRank];
-        }
-        else
-        {
-            outputShape[0] = outputShape[0] / mGroup.size();
         }
-        auto output = torch::empty(outputShape, input.options());
-        bool use_nccl_reducescatter = !sizes.has_value()
-            || std::all_of(sizes.value().begin(), sizes.value().end(),
-                [&sizes](int64_t size) { return size == sizes.value()[0]; });
-        if (use_nccl_reducescatter)
-        {
-            NCCLCHECK_THROW(ncclReduceScatter(input.data_ptr(), output.mutable_data_ptr(), output.numel(),
-                (*getDtypeMap())[type], ncclSum, *mNcclComm, stream));
-        }
-        else
+        std::vector<torch::Tensor> output_list;
+        output_list.reserve(input_list.size());
+        ncclGroupStart();
+        for (auto const& input : input_list)
         {
-            size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
-            int64_t split_offset = 0;
-            ncclGroupStart();
-            for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
+            auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
+            auto type = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
+            std::vector<int64_t> outputShape = input.sizes().vec();
+            if (sizes.has_value())
             {
-                auto split_size = sizes.value()[root];
-                NCCLCHECK_THROW(
+                outputShape[0] = sizes.value()[groupRank];
+            }
+            else
+            {
+                outputShape[0] = outputShape[0] / mGroup.size();
+            }
+            auto output = torch::empty(outputShape, input.options());
+            if (use_nccl_reducescatter)
+            {
+                ncclReduceScatter(input.data_ptr(), output.mutable_data_ptr(), output.numel(), (*getDtypeMap())[type],
+                    ncclSum, *mNcclComm, stream);
+            }
+            else
+            {
+                size_t numel_base
+                    = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
+                int64_t split_offset = 0;
+                for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
+                {
+                    auto split_size = sizes.value()[root];
                     ncclReduce(input.index({torch::indexing::Slice(split_offset, torch::indexing::None)}).data_ptr(),
                         output.mutable_data_ptr(), numel_base * split_size, (*getDtypeMap())[type], ncclSum, root,
-                    *mNcclComm, stream));
-                split_offset += split_size;
+                        *mNcclComm, stream);
+                    split_offset += split_size;
+                }
             }
-            ncclGroupEnd();
+            output_list.push_back(output);
         }
-        return output;
+        NCCLCHECK_THROW(ncclGroupEnd());
+        return output_list;
     }

-    std::vector<torch::Tensor> run_list(
-        torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes) noexcept
+    torch::Tensor run(torch::Tensor const& input, torch::optional<torch::List<int64_t>> sizes)
     {
-        std::vector<torch::Tensor> output_list;
-        output_list.reserve(input_list.size());
-        bool use_nccl_reducescatter = !sizes.has_value()
-            || std::all_of(sizes.value().begin(), sizes.value().end(),
-                [&sizes](int64_t size) { return size == sizes.value()[0]; });
-        if (use_nccl_reducescatter)
-        {
-            ncclGroupStart();
-        }
-        for (auto const& input : input_list)
-        {
-            auto output = run(input, sizes);
-            output_list.push_back(output);
-        }
-        if (use_nccl_reducescatter)
-        {
-            ncclGroupEnd();
-        }
-        return output_list;
+        return run_list({input}, sizes)[0];
     }

 private:
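The sizes-aware path above works around a limitation of ncclReduceScatter, which assumes every rank receives the same element count: when the splits are uneven, the code issues one ncclReduce per root inside the same group, so each root ends up with the reduced sum of only its slice. A rough sketch of that fallback under the same caveats as before (illustrative only; per-rank row counts, buffers, communicator, and stream are assumed to come from the caller):

// Sketch only: uneven reduce-scatter emulated with per-root ncclReduce calls.
// sendBuf holds sum(rows) * rowElems floats on every rank; rank r receives
// rows[r] * rowElems reduced floats in its own recvBuf.
#include <nccl.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

void reduceScatterUneven(float const* sendBuf, float* recvBuf, std::vector<int64_t> const& rows,
    size_t rowElems, ncclComm_t comm, cudaStream_t stream)
{
    ncclGroupStart();
    int64_t offset = 0;
    for (int root = 0; root < static_cast<int>(rows.size()); ++root)
    {
        size_t count = static_cast<size_t>(rows[root]) * rowElems; // elements owned by this root
        // every rank contributes its copy of the slice; only `root` receives the reduced sum
        ncclReduce(sendBuf + offset * rowElems, recvBuf, count, ncclFloat, ncclSum, root, comm, stream);
        offset += rows[root];
    }
    ncclGroupEnd(); // the whole loop launches as one batched operation
}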

examples/llm-api/quickstart_advanced.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def add_llm_args(parser):
     parser.add_argument('--moe_backend',
                         type=str,
                         default='CUTLASS',
-                        choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
+                        choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')

jenkins/Build.groovy

Lines changed: 7 additions & 6 deletions
@@ -306,18 +306,19 @@ def uploadArtifacts(artifacts, prefix = UPLOAD_PATH, retryTimes = 2, serverId =
     for (it in artifacts) {
         def uploadpath = it.key
         def filepath = it.value
-        echo "uploading ${filepath} as ${uploadpath}"
-        trtllm_utils.llmRetry(retryTimes, "uploadArtifacts", {
-            rtUpload (
-                serverId: serverId,
-                spec: """{
+        def spec = """{
             "files": [
                 {
                     "pattern": "${filepath}",
                     "target": "${prefix}/${uploadpath}"
                 }
             ]
-        }""",
+        }"""
+        echo "Uploading ${filepath} as ${uploadpath}. Spec: ${spec}"
+        trtllm_utils.llmRetry(retryTimes, "uploadArtifacts", {
+            rtUpload (
+                serverId: serverId,
+                spec: spec,
             )
         })
     }

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 2 additions & 5 deletions
@@ -62,8 +62,7 @@
 from ..modules.rms_norm import RMSNorm
 from ..peft.lora.layer import LoraLayer
 from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
-from ..utils import (AuxStreamType, EventType, Fp4QuantizedTensor,
-                     disable_fp4_allgather)
+from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              EagerFusionConfig, filter_weights,
                              register_auto_model)
@@ -514,9 +513,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
         if self.use_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if (disable_fp4_allgather()
-                    and not self.experts.enable_alltoall) or isinstance(
-                        self.experts, TRTLLMGenFusedMoE):
+            if isinstance(self.experts, TRTLLMGenFusedMoE):
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,
