@@ -1,4 +1,4 @@
-// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,43 +12,44 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include <xft/xdnn_plugin.h>
 #include "paddle/extension.h"
 #include "xpu/plugin.h"
-#include <paddle/phi/backends/xpu/xpu_context.h>
-std::vector<paddle::Tensor>
-GatherNextToken(const paddle::Tensor &tmp_out, // [token_num, dim_embed]
-                const paddle::Tensor &cum_offsets, // [bsz, 1]
-                const paddle::Tensor &encoder_seq_lod,
-                const paddle::Tensor &encoder_batch_map,
-                const paddle::Tensor &decoder_batch_map,
-                const paddle::Tensor &encoder_seq_lod_cpu,
-                const paddle::Tensor &encoder_batch_map_cpu,
-                const paddle::Tensor &decoder_batch_map_cpu,
-                const paddle::Tensor &enc_batch_tensor,
-                const paddle::Tensor &dec_batch_tensor,
-                const paddle::optional<paddle::Tensor> &output_padding_offset,
-                int max_input_length) {
-  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
-  auto dev_ctx =
-      paddle::experimental::DeviceContextPool::Instance().Get(place);
-  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
-  using XPUType =
-      typename XPUTypeTrait<bfloat16>::Type; // only support bfloat16
-  typedef paddle::bfloat16 data_t;
-  const int dim = tmp_out.dims()[1];
-  const int bsz = cum_offsets.shape()[0];
-  int enc_batch = enc_batch_tensor.data<int32_t>()[0];
-  int dec_batch = dec_batch_tensor.data<int32_t>()[0];
 
-  baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
-      const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
-      enc_batch + 1, const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
-  baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
-      const_cast<int32_t *>(encoder_batch_map_cpu.data<int32_t>()), enc_batch,
-      const_cast<int32_t *>(encoder_batch_map.data<int32_t>())};
-  baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
-      const_cast<int32_t *>(decoder_batch_map_cpu.data<int32_t>()), dec_batch,
-      const_cast<int32_t *>(decoder_batch_map.data<int32_t>())};
+std::vector<paddle::Tensor> GatherNextToken(
+    const paddle::Tensor& x,            // [token_num, dim_embed]
+    const paddle::Tensor& cum_offsets,  // [bsz, 1]
+    const paddle::Tensor& encoder_seq_lod,
+    const paddle::Tensor& encoder_batch_map,
+    const paddle::Tensor& decoder_batch_map,
+    const paddle::Tensor& encoder_seq_lod_cpu,
+    const paddle::Tensor& encoder_batch_map_cpu,
+    const paddle::Tensor& decoder_batch_map_cpu,
+    const paddle::Tensor& len_info_cpu,
+    const paddle::optional<paddle::Tensor>& output_padding_offset) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  using XPUType =
+      typename XPUTypeTrait<bfloat16>::Type;  // only support bfloat16
+  typedef paddle::bfloat16 data_t;
+  const int dim = x.dims()[1];
+  const int token_num = x.shape()[0];
+  const int bsz = cum_offsets.shape()[0];
+  int enc_batch = len_info_cpu.data<int32_t>()[0];
+  int dec_batch = len_info_cpu.data<int32_t>()[1];
 
-  auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place());
+  baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
+      const_cast<int32_t*>(encoder_seq_lod_cpu.data<int32_t>()),
+      enc_batch + 1,
+      const_cast<int32_t*>(encoder_seq_lod.data<int32_t>())};
+  baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
+      const_cast<int32_t*>(encoder_batch_map_cpu.data<int32_t>()),
+      enc_batch,
+      const_cast<int32_t*>(encoder_batch_map.data<int32_t>())};
+  baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
+      const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
+      dec_batch,
+      const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};
 
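
Note on the three VectorParam<int32_t> arguments built above: each pairs a host copy of a lod/batch map (the *_cpu tensor) with its device mirror, so the XDNN plugin can pick whichever side it needs without an extra transfer. Below is a minimal standalone sketch of that pattern, assuming only the field layout this file itself relies on (cpu, len, xpu, with xpu == nullptr meaning the host copy is authoritative); the real definition lives in the XDNN headers, and the struct name here is hypothetical.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for baidu::xpu::api::VectorParam<int32_t>; it mirrors
// only the members this file touches and is for illustration, not a drop-in.
template <typename T>
struct VectorParamSketch {
  const T* cpu;  // host-side copy of the lod / batch map
  int len;       // element count (enc_batch + 1 for a lod over enc_batch seqs)
  T* xpu;        // device-side mirror; nullptr falls back to the host copy
};

int main() {
  // Two prefill sequences of lengths 3 and 5 give the lod {0, 3, 8}.
  std::vector<int32_t> lod_cpu = {0, 3, 8};
  VectorParamSketch<int32_t> lod_vp{
      lod_cpu.data(), static_cast<int>(lod_cpu.size()), /*xpu=*/nullptr};
  (void)lod_vp;  // silence unused-variable warnings in this sketch
  return 0;
}
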
@@ -55,8 +56,38 @@
+  paddle::Tensor out;
+  std::vector<int> encode_iota_lod_cpu(enc_batch);
+  if (output_padding_offset) {
+    int need_delete_token_num = 0;
+    if (enc_batch > 0) {
+      need_delete_token_num =
+          encoder_seq_lod_cpu.data<int32_t>()[enc_batch] - enc_batch;
+      std::iota(encode_iota_lod_cpu.begin(), encode_iota_lod_cpu.end(), 0);
+      encoder_batch_map_vp.cpu =
+          const_cast<const int32_t*>(encode_iota_lod_cpu.data());
+      encoder_batch_map_vp.len = enc_batch;
+      encoder_batch_map_vp.xpu = nullptr;
+    }
+    out = paddle::empty(
+        {token_num - need_delete_token_num, dim}, x.type(), x.place());
+  } else {
+    out = paddle::empty({bsz, dim}, x.type(), x.place());
+  }
+  if (x.shape()[0] == 0) {
+    return {out};
+  }
+
+  if (output_padding_offset && enc_batch <= 0) {
+    out = x.copy_to(x.place(), false);
+  } else {
   int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
       xpu_ctx->x_context(),
-      reinterpret_cast<const XPUType *>(tmp_out.data<data_t>()),
-      reinterpret_cast<XPUType *>(out.data<data_t>()), encoder_seqs_lods_vp,
-      encoder_batch_map_vp, decoder_batch_map_vp, dim);
-  return {out};
+      reinterpret_cast<const XPUType*>(x.data<data_t>()),
+      reinterpret_cast<XPUType*>(out.data<data_t>()),
+      encoder_seqs_lods_vp,
+      encoder_batch_map_vp,
+      decoder_batch_map_vp,
+      dim);
+    PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
+  }
+  return {out};
 }
 
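
The sizing logic in the new output_padding_offset branch is easiest to see with concrete numbers: every prefill (encoder) sequence contributes all of its tokens to x but only its last token to out, while every decode sequence contributes exactly one token that passes straight through. A self-contained walk-through with invented sizes:

#include <cstdio>
#include <vector>

// Row-count arithmetic from GatherNextToken, with made-up sizes: the lod
// {0, 3, 8} describes two prefill sequences (3 and 5 tokens), and two decode
// sequences add one token each.
int main() {
  std::vector<int> encoder_seq_lod = {0, 3, 8};
  int enc_batch = 2;
  int dec_batch = 2;
  int token_num = encoder_seq_lod[enc_batch] + dec_batch;              // 10
  int need_delete_token_num = encoder_seq_lod[enc_batch] - enc_batch;  // 6
  // One surviving token per prefill sequence plus all decode tokens:
  std::printf("out rows = %d\n", token_num - need_delete_token_num);   // 4
  return 0;
}
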
@@ -63,19 +94,23 @@
 std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
-    const std::vector<int64_t> &tmp_out_shape,
-    const std::vector<int64_t> &cum_offsets_shape,
-    const std::vector<int64_t> &encoder_seq_lod_shape,
-    const std::vector<int64_t> &encoder_batch_map_shape,
-    const std::vector<int64_t> &decoder_batch_map_shape,
-    const std::vector<int64_t> &encoder_seq_lod_cpu_shape,
-    const std::vector<int64_t> &encoder_batch_map_cpu_shape,
-    const std::vector<int64_t> &decoder_batch_map_cpu_shape,
-    const std::vector<int64_t> &enc_batch_tensor_shape,
-    const std::vector<int64_t> &dec_batch_tensor_shape,
-    const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
-  if (output_padding_offset_shape) {
-    PD_THROW("speculative decoding is not supported in XPU.");
-  }
+    const std::vector<int64_t>& x_shape,
+    const std::vector<int64_t>& cum_offsets_shape,
+    const std::vector<int64_t>& encoder_seq_lod_shape,
+    const std::vector<int64_t>& encoder_batch_map_shape,
+    const std::vector<int64_t>& decoder_batch_map_shape,
+    const std::vector<int64_t>& encoder_seq_lod_cpu_shape,
+    const std::vector<int64_t>& encoder_batch_map_cpu_shape,
+    const std::vector<int64_t>& decoder_batch_map_cpu_shape,
+    const std::vector<int64_t>& len_info_cpu_shape,
+    const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
+  // if (output_padding_offset_shape) {
+  //   PD_THROW("speculative decoding is not supported in XPU.");
+  // }
+  int64_t bsz = cum_offsets_shape[0];
+  int64_t dim_embed = x_shape[1];
+  if (output_padding_offset_shape) {
+    return {{-1, dim_embed}};
+  } else {
   int64_t bsz = cum_offsets_shape[0];
-  int64_t dim_embed = tmp_out_shape[1];
   return {{bsz, dim_embed}};
+  }
 }
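
With the PD_THROW guard commented out, the padding-offset (speculative-decoding) path now reaches shape inference, and its row count is only known at run time, hence the -1 leading dimension. A dependency-free restatement of the rule (function name invented for illustration):

#include <cstdint>
#include <vector>

// The two shape outcomes of GatherNextTokenInferShape, without the Paddle
// types: {bsz, dim_embed} on the dense path, {-1, dim_embed} (dynamic row
// count) when output_padding_offset is supplied.
std::vector<int64_t> infer_out_shape(int64_t bsz, int64_t dim_embed,
                                     bool has_output_padding_offset) {
  if (has_output_padding_offset) {
    return {-1, dim_embed};
  }
  return {bsz, dim_embed};
}

int main() {
  auto dense = infer_out_shape(4, 8192, false);  // {4, 8192}
  auto spec = infer_out_shape(4, 8192, true);    // {-1, 8192}
  (void)dense;
  (void)spec;
  return 0;
}
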
@@ -82,16 +117,15 @@
 
 std::vector<paddle::DataType> GatherNextTokenInferDtype(
-    const paddle::DataType &tmp_out_dtype,
-    const paddle::DataType &cum_offsets_dtype,
-    const paddle::DataType &encoder_seq_lod_dtype,
-    const paddle::DataType &encoder_batch_map_dtype,
-    const paddle::DataType &decoder_batch_map_dtype,
-    const paddle::DataType &encoder_seq_lod_cpu_dtype,
-    const paddle::DataType &encoder_batch_map_cpu_dtype,
-    const paddle::DataType &decoder_batch_map_cpu_dtype,
-    const paddle::DataType &enc_batch_tensor_dtype,
-    const paddle::DataType &dec_batch_tensor_dtype,
-    const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
-  return {tmp_out_dtype};
+    const paddle::DataType& x_dtype,
+    const paddle::DataType& cum_offsets_dtype,
+    const paddle::DataType& encoder_seq_lod_dtype,
+    const paddle::DataType& encoder_batch_map_dtype,
+    const paddle::DataType& decoder_batch_map_dtype,
+    const paddle::DataType& encoder_seq_lod_cpu_dtype,
+    const paddle::DataType& encoder_batch_map_cpu_dtype,
+    const paddle::DataType& decoder_batch_map_cpu_dtype,
+    const paddle::DataType& len_info_cpu_dtype,
+    const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
+  return {x_dtype};
 }
 
@@ -98,11 +132,15 @@
 PD_BUILD_OP(gather_next_token)
-    .Inputs({"tmp_out", "cum_offsets", "encoder_seq_lod", "encoder_batch_map",
-             "decoder_batch_map", "encoder_seq_lod_cpu",
-             "encoder_batch_map_cpu", "decoder_batch_map_cpu",
-             "enc_batch_tensor", "dec_batch_tensor",
+    .Inputs({"x",
+             "cum_offsets",
+             "encoder_seq_lod",
+             "encoder_batch_map",
+             "decoder_batch_map",
+             "encoder_seq_lod_cpu",
+             "encoder_batch_map_cpu",
+             "decoder_batch_map_cpu",
+             "len_info_cpu",
              paddle::Optional("output_padding_offset")})
     .Outputs({"out"})
-    .Attrs({"max_input_length: int"})
     .SetKernelFn(PD_KERNEL(GatherNextToken))
     .SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
+    .SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
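
Two caller-visible changes fall out of the registration above: the max_input_length attribute is gone, and the old enc_batch_tensor / dec_batch_tensor pair collapses into one two-element len_info_cpu tensor. A hypothetical caller-side helper (not part of this commit) showing the layout the kernel reads, enc_batch at index 0 and dec_batch at index 1:

#include <cstdint>

#include "paddle/extension.h"

// Hypothetical helper: packs the two batch counts into the int32 CPU tensor
// that gather_next_token now consumes as len_info_cpu.
paddle::Tensor MakeLenInfoCpu(int32_t enc_batch, int32_t dec_batch) {
  paddle::Tensor len_info_cpu =
      paddle::empty({2}, paddle::DataType::INT32, paddle::CPUPlace());
  int32_t* data = len_info_cpu.data<int32_t>();
  data[0] = enc_batch;
  data[1] = dec_batch;
  return len_info_cpu;
}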