
Commit ca790ff

[Feature][XPU] support MTP

1 parent 667dc4a commit ca790ff

File tree

20 files changed: +1031 −494 lines

custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc

Lines changed: 3 additions & 0 deletions

```diff
@@ -23,6 +23,9 @@
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif
 
+#define GET_OUTPUT_DEBUG
+#define SAVE_WITH_OUTPUT_DEBUG
+
 #include "speculate_msg.h"
 
 void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
```
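The two new defines read like compile-time debug switches: defining GET_OUTPUT_DEBUG and SAVE_WITH_OUTPUT_DEBUG before including speculate_msg.h turns on whatever logging that header guards behind them, for every build. A minimal sketch of the guard pattern this presumably enables (only the macro names come from the diff; the logging body is an assumption):

```cpp
#include <cstdio>

// Hypothetical guard of the kind speculate_msg.h presumably keys off these
// macros; illustrative only, the real header's logging may differ.
#ifdef SAVE_WITH_OUTPUT_DEBUG
#define SAVE_DEBUG_LOG(fmt, ...) \
  std::fprintf(stderr, "[speculate_save] " fmt "\n", ##__VA_ARGS__)
#else
#define SAVE_DEBUG_LOG(fmt, ...) ((void)0)
#endif

int main() {
  SAVE_DEBUG_LOG("accept_num = %d", 3);  // compiled in only when the flag is defined
}
```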

custom_ops/xpu_ops/src/ops/block_attn.cc

Lines changed: 1 addition & 1 deletion

```diff
@@ -623,7 +623,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
             : quant_v_scale_inv,
         nullptr,          // o_maxptr
         param.head_dim);  // vo_head_dim
-    PD_CHECK(0, "speculative_attention unimplemented");
+    // PD_CHECK(0, "speculative_attention unimplemented");
     PD_CHECK(ret == api::SUCCESS,
              "xfa::speculative_attention_decoder failed.");
     if (!Eq_len) {
```
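Note: this one-line change is what turns the speculative-attention decode path on for XPU. The unconditional PD_CHECK(0, ...) aborted every call after xfa::speculative_attention_decoder returned; the remaining PD_CHECK(ret == api::SUCCESS, ...) still surfaces genuine kernel failures.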
Lines changed: 111 additions & 73 deletions
```diff
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,97 +12,135 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include <xft/xdnn_plugin.h>
 #include "paddle/extension.h"
 #include "xpu/plugin.h"
-#include <paddle/phi/backends/xpu/xpu_context.h>
-std::vector<paddle::Tensor>
-GatherNextToken(const paddle::Tensor &tmp_out,     // [token_num, dim_embed]
-                const paddle::Tensor &cum_offsets, // [bsz, 1]
-                const paddle::Tensor &encoder_seq_lod,
-                const paddle::Tensor &encoder_batch_map,
-                const paddle::Tensor &decoder_batch_map,
-                const paddle::Tensor &encoder_seq_lod_cpu,
-                const paddle::Tensor &encoder_batch_map_cpu,
-                const paddle::Tensor &decoder_batch_map_cpu,
-                const paddle::Tensor &enc_batch_tensor,
-                const paddle::Tensor &dec_batch_tensor,
-                const paddle::optional<paddle::Tensor> &output_padding_offset,
-                int max_input_length) {
-  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
-  auto dev_ctx =
-      paddle::experimental::DeviceContextPool::Instance().Get(place);
-  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
-  using XPUType =
-      typename XPUTypeTrait<bfloat16>::Type; // only support bfloat16
-  typedef paddle::bfloat16 data_t;
-  const int dim = tmp_out.dims()[1];
-  const int bsz = cum_offsets.shape()[0];
-  int enc_batch = enc_batch_tensor.data<int32_t>()[0];
-  int dec_batch = dec_batch_tensor.data<int32_t>()[0];
+std::vector<paddle::Tensor> GatherNextToken(
+    const paddle::Tensor& x,            // [token_num, dim_embed]
+    const paddle::Tensor& cum_offsets,  // [bsz, 1]
+    const paddle::Tensor& encoder_seq_lod,
+    const paddle::Tensor& encoder_batch_map,
+    const paddle::Tensor& decoder_batch_map,
+    const paddle::Tensor& encoder_seq_lod_cpu,
+    const paddle::Tensor& encoder_batch_map_cpu,
+    const paddle::Tensor& decoder_batch_map_cpu,
+    const paddle::Tensor& len_info_cpu,
+    const paddle::optional<paddle::Tensor>& output_padding_offset) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  using XPUType =
+      typename XPUTypeTrait<bfloat16>::Type;  // only support bfloat16
+  typedef paddle::bfloat16 data_t;
+  const int dim = x.dims()[1];
+  const int token_num = x.shape()[0];
+  const int bsz = cum_offsets.shape()[0];
+  int enc_batch = len_info_cpu.data<int32_t>()[0];
+  int dec_batch = len_info_cpu.data<int32_t>()[1];
 
-  baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
-      const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
-      enc_batch + 1, const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
-  baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
-      const_cast<int32_t *>(encoder_batch_map_cpu.data<int32_t>()), enc_batch,
-      const_cast<int32_t *>(encoder_batch_map.data<int32_t>())};
-  baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
-      const_cast<int32_t *>(decoder_batch_map_cpu.data<int32_t>()), dec_batch,
-      const_cast<int32_t *>(decoder_batch_map.data<int32_t>())};
+  baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
+      const_cast<int32_t*>(encoder_seq_lod_cpu.data<int32_t>()),
+      enc_batch + 1,
+      const_cast<int32_t*>(encoder_seq_lod.data<int32_t>())};
+  baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
+      const_cast<int32_t*>(encoder_batch_map_cpu.data<int32_t>()),
+      enc_batch,
+      const_cast<int32_t*>(encoder_batch_map.data<int32_t>())};
+  baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
+      const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
+      dec_batch,
+      const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};
 
-  auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place());
+  paddle::Tensor out;
+  std::vector<int> encode_iota_lod_cpu(enc_batch);
+  if (output_padding_offset) {
+    int need_delete_token_num = 0;
+    if (enc_batch > 0) {
+      need_delete_token_num =
+          encoder_seq_lod_cpu.data<int32_t>()[enc_batch] - enc_batch;
+      std::iota(encode_iota_lod_cpu.begin(), encode_iota_lod_cpu.end(), 0);
+      encoder_batch_map_vp.cpu =
+          const_cast<const int32_t*>(encode_iota_lod_cpu.data());
+      encoder_batch_map_vp.len = enc_batch;
+      encoder_batch_map_vp.xpu = nullptr;
+    }
+    out = paddle::empty(
+        {token_num - need_delete_token_num, dim}, x.type(), x.place());
+  } else {
+    out = paddle::empty({bsz, dim}, x.type(), x.place());
+  }
+  if (x.shape()[0] == 0) {
+    return {out};
+  }
+
+  if (output_padding_offset && enc_batch <= 0) {
+    out = x.copy_to(x.place(), false);
+  } else {
     int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
         xpu_ctx->x_context(),
-        reinterpret_cast<const XPUType *>(tmp_out.data<data_t>()),
-        reinterpret_cast<XPUType *>(out.data<data_t>()), encoder_seqs_lods_vp,
-        encoder_batch_map_vp, decoder_batch_map_vp, dim);
-  return {out};
+        reinterpret_cast<const XPUType*>(x.data<data_t>()),
+        reinterpret_cast<XPUType*>(out.data<data_t>()),
+        encoder_seqs_lods_vp,
+        encoder_batch_map_vp,
+        decoder_batch_map_vp,
+        dim);
+    PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
+  }
+  return {out};
 }
 
 std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
-    const std::vector<int64_t> &tmp_out_shape,
-    const std::vector<int64_t> &cum_offsets_shape,
-    const std::vector<int64_t> &encoder_seq_lod_shape,
-    const std::vector<int64_t> &encoder_batch_map_shape,
-    const std::vector<int64_t> &decoder_batch_map_shape,
-    const std::vector<int64_t> &encoder_seq_lod_cpu_shape,
-    const std::vector<int64_t> &encoder_batch_map_cpu_shape,
-    const std::vector<int64_t> &decoder_batch_map_cpu_shape,
-    const std::vector<int64_t> &enc_batch_tensor_shape,
-    const std::vector<int64_t> &dec_batch_tensor_shape,
-    const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
-  if (output_padding_offset_shape) {
-    PD_THROW("speculative decoding is not supported in XPU.");
-  }
+    const std::vector<int64_t>& x_shape,
+    const std::vector<int64_t>& cum_offsets_shape,
+    const std::vector<int64_t>& encoder_seq_lod_shape,
+    const std::vector<int64_t>& encoder_batch_map_shape,
+    const std::vector<int64_t>& decoder_batch_map_shape,
+    const std::vector<int64_t>& encoder_seq_lod_cpu_shape,
+    const std::vector<int64_t>& encoder_batch_map_cpu_shape,
+    const std::vector<int64_t>& decoder_batch_map_cpu_shape,
+    const std::vector<int64_t>& len_info_cpu_shape,
+    const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
+  // if (output_padding_offset_shape) {
+  //   PD_THROW("speculative decoding is not supported in XPU.");
+  // }
+  int64_t bsz = cum_offsets_shape[0];
+  int64_t dim_embed = x_shape[1];
+  if (output_padding_offset_shape) {
+    return {{-1, dim_embed}};
+  } else {
     int64_t bsz = cum_offsets_shape[0];
-  int64_t dim_embed = tmp_out_shape[1];
     return {{bsz, dim_embed}};
+  }
 }
 
 std::vector<paddle::DataType> GatherNextTokenInferDtype(
-    const paddle::DataType &tmp_out_dtype,
-    const paddle::DataType &cum_offsets_dtype,
-    const paddle::DataType &encoder_seq_lod_dtype,
-    const paddle::DataType &encoder_batch_map_dtype,
-    const paddle::DataType &decoder_batch_map_dtype,
-    const paddle::DataType &encoder_seq_lod_cpu_dtype,
-    const paddle::DataType &encoder_batch_map_cpu_dtype,
-    const paddle::DataType &decoder_batch_map_cpu_dtype,
-    const paddle::DataType &enc_batch_tensor_dtype,
-    const paddle::DataType &dec_batch_tensor_dtype,
-    const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
-  return {tmp_out_dtype};
+    const paddle::DataType& x_dtype,
+    const paddle::DataType& cum_offsets_dtype,
+    const paddle::DataType& encoder_seq_lod_dtype,
+    const paddle::DataType& encoder_batch_map_dtype,
+    const paddle::DataType& decoder_batch_map_dtype,
+    const paddle::DataType& encoder_seq_lod_cpu_dtype,
+    const paddle::DataType& encoder_batch_map_cpu_dtype,
+    const paddle::DataType& decoder_batch_map_cpu_dtype,
+    const paddle::DataType& len_info_cpu_dtype,
+    const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
+  return {x_dtype};
 }
 
 PD_BUILD_OP(gather_next_token)
-    .Inputs({"tmp_out", "cum_offsets", "encoder_seq_lod", "encoder_batch_map",
-             "decoder_batch_map", "encoder_seq_lod_cpu",
-             "encoder_batch_map_cpu", "decoder_batch_map_cpu",
-             "enc_batch_tensor", "dec_batch_tensor",
+    .Inputs({"x",
+             "cum_offsets",
+             "encoder_seq_lod",
+             "encoder_batch_map",
+             "decoder_batch_map",
+             "encoder_seq_lod_cpu",
+             "encoder_batch_map_cpu",
+             "decoder_batch_map_cpu",
+             "len_info_cpu",
              paddle::Optional("output_padding_offset")})
     .Outputs({"out"})
-    .Attrs({"max_input_length: int"})
     .SetKernelFn(PD_KERNEL(GatherNextToken))
     .SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
+    .SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
```
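Two things change in this kernel besides formatting. First, the per-batch counts now arrive packed in one host tensor: enc_batch = len_info_cpu[0] and dec_batch = len_info_cpu[1], replacing the separate enc_batch_tensor/dec_batch_tensor inputs and the unused max_input_length attribute. Second, the output_padding_offset path, which InferShape previously rejected with PD_THROW, is now the speculative-decoding path: each encoder sequence keeps one gathered token, so lod[enc_batch] - enc_batch rows are dropped from the output and the encoder batch map is rewritten to an iota over 0..enc_batch-1. A standalone sketch of that sizing arithmetic (the lod values are made up, and it assumes each decoder sequence contributes one token here):

```cpp
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // encoder_seq_lod_cpu for enc_batch = 2 prefill sequences of lengths 3 and 4:
  std::vector<int> encoder_seq_lod_cpu = {0, 3, 7};
  int enc_batch = 2;
  int dec_batch = 3;                                       // one token per decode step
  int token_num = encoder_seq_lod_cpu[enc_batch] + dec_batch;  // 10 packed rows in x

  // Drop all but the last position of every encoder sequence:
  int need_delete_token_num =
      encoder_seq_lod_cpu[enc_batch] - enc_batch;          // 7 - 2 = 5
  int out_rows = token_num - need_delete_token_num;        // 10 - 5 = 5

  // The rewritten batch map sends encoder outputs to slots 0..enc_batch-1:
  std::vector<int> encode_iota_lod_cpu(enc_batch);
  std::iota(encode_iota_lod_cpu.begin(), encode_iota_lod_cpu.end(), 0);  // {0, 1}

  std::printf("out shape: {%d, dim}\n", out_rows);
  return 0;
}
```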

custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc

Lines changed: 5 additions & 1 deletion
```diff
@@ -17,6 +17,10 @@
 #include "paddle/phi/core/enforce.h"
 #include "xpu/plugin.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 namespace api = baidu::xpu::api;
 void DraftModelPreprocessV2(const paddle::Tensor& draft_tokens,
                             const paddle::Tensor& input_ids,
@@ -99,7 +103,7 @@ void DraftModelPreprocessV2(const paddle::Tensor& draft_tokens,
   not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
 }
 
-PD_BUILD_OP(draft_model_preprocess_v2)
+PD_BUILD_STATIC_OP(draft_model_preprocess_v2)
     .Inputs({"draft_tokens",
              "input_ids",
              "stop_flags",
```
Lines changed: 131 additions & 0 deletions (new file)

```cpp
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

std::vector<paddle::Tensor> SpeculateGetPaddingOffsetV2(
    const paddle::Tensor& input_ids,
    const paddle::Tensor& draft_tokens,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& token_num,
    const paddle::Tensor& seq_len,
    const paddle::Tensor& seq_lens_encoder) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);

  std::vector<int64_t> input_ids_shape = input_ids.shape();
  const int bsz = seq_len.shape()[0];
  const int seq_length = input_ids_shape[1];
  const int max_draft_tokens = draft_tokens.shape()[1];
  auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
  auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);

  const int token_num_data = cpu_token_num.data<int64_t>()[0];
  auto x_remove_padding = paddle::empty(
      {token_num_data}, paddle::DataType::INT64, input_ids.place());
  auto padding_offset = paddle::empty(
      {token_num_data}, paddle::DataType::INT32, input_ids.place());
  auto batch_id_per_token = paddle::empty(
      {token_num_data}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());

  PD_CHECK(input_ids.is_contiguous(), "Input ids tensor must be contiguous");
  PD_CHECK(draft_tokens.is_contiguous(),
           "Draft tokens tensor must be contiguous");
  PD_CHECK(cum_offsets.is_contiguous(),
           "Cum offsets tensor must be contiguous");
  PD_CHECK(seq_len.is_contiguous(), "Seq lens tensor must be contiguous");

  int r = baidu::xpu::api::plugin::speculate_get_padding_offset_v2(
      xpu_ctx->x_context(),
      batch_id_per_token.data<int>(),
      cum_offsets_out.data<int>(),
      cu_seqlens_q.data<int>(),
      cu_seqlens_k.data<int>(),
      cum_offsets.data<int>(),
      seq_len.data<int>(),
      seq_length,
      bsz);
  PD_CHECK(r == 0, "XPU speculate_get_padding_offset_v2 failed");

  r = baidu::xpu::api::plugin::speculate_remove_padding<int64_t>(
      xpu_ctx->x_context(),
      x_remove_padding.data<int64_t>(),
      input_ids.data<int64_t>(),
      draft_tokens.data<int64_t>(),
      seq_len.data<int>(),
      seq_lens_encoder.data<int>(),
      cum_offsets_out.data<int>(),
      seq_length,
      max_draft_tokens,
      bsz,
      token_num_data);
  PD_CHECK(r == 0, "XPU speculate_remove_padding failed");

  return {x_remove_padding,
          batch_id_per_token,
          cu_seqlens_q,
          cu_seqlens_k};  // , enc_token_num, dec_token_num};
}

std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetV2InferShape(
    const std::vector<int64_t>& input_ids_shape,
    const std::vector<int64_t>& draft_tokens_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& token_num_shape,
    const std::vector<int64_t>& seq_len_shape,
    const std::vector<int64_t>& seq_lens_encoder_shape) {
  int64_t bsz = seq_len_shape[0];
  int64_t seq_len = input_ids_shape[1];
  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}

std::vector<paddle::DataType> SpeculateGetPaddingOffsetV2InferDtype(
    const paddle::DataType& input_ids_dtype,
    const paddle::DataType& draft_tokens_dtype,
    const paddle::DataType& cum_offsets_dtype,
    const paddle::DataType& token_num_dtype,
    const paddle::DataType& seq_len_dtype,
    const paddle::DataType& seq_lens_encoder_dtype) {
  return {input_ids_dtype,
          seq_len_dtype,
          seq_len_dtype,
          seq_len_dtype,
          seq_len_dtype};
}

PD_BUILD_STATIC_OP(speculate_get_padding_offset_v2)
    .Inputs({"input_ids",
             "draft_tokens",
             "cum_offsets",
             "token_num",
             "seq_len",
             "seq_lens_encoder"})
    .Outputs({"x_remove_padding",
              "batch_id_per_token",
              "cu_seqlens_q",
              "cu_seqlens_k"})
    .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffsetV2))
    .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetV2InferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetV2InferDtype));
```
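This op flattens the padded [bsz, seq_length] input_ids (plus [bsz, max_draft_tokens] draft_tokens for in-decode requests) into a dense [token_num] buffer and emits the cu_seqlens_q/cu_seqlens_k prefix sums that attention kernels consume. Note that the infer-shape and infer-dtype functions still return five entries while only four outputs are registered, apparently a leftover of the commented-out enc_token_num/dec_token_num in the kernel's return. A rough CPU reference for what speculate_remove_padding presumably computes, inferred from its argument list rather than from the XPU plugin source:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// CPU sketch of speculate_remove_padding's presumed contract: for each batch
// entry, copy seq_len[i] tokens into a dense buffer, taking them from
// input_ids while the request is in the encoder (prefill) phase and from
// draft_tokens otherwise. Assumption inferred from the kernel's arguments;
// the real XPU plugin may differ in details.
std::vector<int64_t> remove_padding_ref(const std::vector<int64_t>& input_ids,
                                        const std::vector<int64_t>& draft_tokens,
                                        const std::vector<int>& seq_len,
                                        const std::vector<int>& seq_lens_encoder,
                                        int seq_length, int max_draft_tokens) {
  std::vector<int64_t> out;
  for (std::size_t i = 0; i < seq_len.size(); ++i) {
    const bool is_encoder = seq_lens_encoder[i] > 0;  // prefill vs. decode
    const int64_t* src = is_encoder ? &input_ids[i * seq_length]
                                    : &draft_tokens[i * max_draft_tokens];
    out.insert(out.end(), src, src + seq_len[i]);  // drop the padding tail
  }
  return out;
}
```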

custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc

Lines changed: 1 addition & 2 deletions
```diff
@@ -35,8 +35,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
                                 const paddle::Tensor& not_need_stop,
                                 int64_t rank_id,
                                 int msg_queue_id,
-                                int save_each_rank) {
-  // printf("enter save output");
+                                bool save_each_rank) {
   if (!save_each_rank && rank_id > 0) {
     return;
   }
```
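Note: save_each_rank becomes a real bool rather than an int, matching how the guard below it uses the value, and a stale commented-out debug printf is dropped. With save_each_rank false, only rank 0 writes the accepted tokens to the message queue.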
