Skip to content

Commit

Permalink
[GPU] Add subsequent reshapes optimization and dynamic paddings suppo…
Browse files Browse the repository at this point in the history
…rt for RoPE and PagedAttention
  • Loading branch information
sshlyapn committed Oct 30, 2024
1 parent 884ac4a commit c1e73e9
Show file tree
Hide file tree
Showing 9 changed files with 337 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -659,23 +659,34 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
if (user_info.first && user_info.first->is_type<reshape>()) {
auto reshape_desc = user_info.first->as<reshape>().get_primitive();
auto reshape_mode = reshape_desc->mode;
auto reshape_axis = crop_axis;
if (reshape_mode == reshape::reshape_mode::base) {
user_info.second.data_padding._dynamic_dims_mask = dyn_pad_sizes;
auto reshape_ps = user_info.second.get_partial_shape();
auto crop_dim_val = crop_layout.get_partial_shape()[crop_axis].get_length();

auto mul = 1;
reshape_axis = reshape_ps.size() - 1;
for (size_t i = reshape_ps.size(); i > 1; i--) {
if (reshape_ps[i - 1].is_dynamic() || mul == crop_dim_val)
break;

mul *= reshape_ps[i - 1].get_length();
reshape_axis = i - 1;
}
} else if (reshape_mode == reshape::reshape_mode::unsqueeze || reshape_mode == reshape::reshape_mode::squeeze) {
auto reshape_ps = user_info.second.get_partial_shape();
auto output_pattern = reshape_desc->output_pattern;

auto reshape_axis = crop_axis;
for (size_t i = 0; i < output_pattern.size(); i++) {
if (output_pattern[i] <= static_cast<int64_t>(reshape_axis)) {
reshape_axis += reshape_mode == reshape::reshape_mode::unsqueeze ? 1 : -1;
}
}

padding::DynamicDimsMask dyn_pad_mask;
dyn_pad_mask[reshape_axis] = 1;
user_info.second.data_padding._dynamic_dims_mask = dyn_pad_mask;
}

auto reshape_dyn_pad_mask = padding::DynamicDimsMask();
reshape_dyn_pad_mask[reshape_axis] = 1;
user_info.second.data_padding._dynamic_dims_mask = reshape_dyn_pad_mask;
}
return;
}
Expand Down Expand Up @@ -703,13 +714,36 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
auto reshape_desc = user_info.first->as<reshape>().get_primitive();
auto reshape_mode = reshape_desc->mode;
if (reshape_mode == reshape::reshape_mode::base) {
auto reshape_rank = user_info.second.get_partial_shape().size();
auto reshape_last_dim = user_info.second.get_partial_shape().to_shape()[reshape_rank - 1];
if (lower_sizes[crop_axis])
lower_sizes[crop_axis] /= reshape_last_dim;
if (upper_sizes[crop_axis])
upper_sizes[crop_axis] /= reshape_last_dim;
user_info.second.data_padding = padding(lower_sizes, upper_sizes, dyn_pad_sizes);
auto reshape_ps = user_info.second.get_partial_shape();
auto crop_dim_val = crop_layout.get_partial_shape()[crop_axis].get_length();

auto divider = 1;
auto reshape_axis = reshape_ps.size();
for (size_t i = reshape_ps.size(); i > 1; i--) {
const auto& dim_value = reshape_ps[i - 1].get_length();
if (divider * dim_value == crop_dim_val)
break;

divider *= dim_value;
reshape_axis = i - 1;
}
reshape_axis -= 1;

const auto output_rank = std::max(reshape_ps.size(), static_cast<size_t>(4));
std::vector<int32_t> reshape_lower_sizes(output_rank, 0);
std::vector<int32_t> reshape_upper_sizes(output_rank, 0);
padding::DynamicDimsMask reshape_dyn_pad_mask;

reshape_lower_sizes[reshape_axis] = lower_sizes[crop_axis];
reshape_upper_sizes[reshape_axis] = upper_sizes[crop_axis];
reshape_dyn_pad_mask[reshape_axis] = 1;

if (reshape_lower_sizes[reshape_axis])
reshape_lower_sizes[reshape_axis] /= divider;
if (reshape_upper_sizes[reshape_axis])
reshape_upper_sizes[reshape_axis] /= divider;

user_info.second.data_padding = padding(reshape_lower_sizes, reshape_upper_sizes, reshape_dyn_pad_mask);
} else {
auto reshape_ps = user_info.second.get_partial_shape();
auto output_pattern = reshape_desc->output_pattern;
Expand Down
15 changes: 9 additions & 6 deletions src/plugins/intel_gpu/src/graph/include/reshape_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
return false;

// TODO: If user is RoPE or MVN and dynamic padding exists, ouput padding propagation is not supported in the base mode
if (get_users().size() == 1 && (get_users().front()->is_type<rope>() || get_users().front()->is_type<mvn>()))
if (get_users().size() == 1 && get_users().front()->is_type<mvn>())
return false;

auto axis = input().as<crop>().get_primitive()->axis;
Expand All @@ -73,14 +73,17 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
const auto& output_pshape = prim->output_partial_shape;
// TODO: If the reshape's output shape is non constant, issue occurs
// during shape inference due to execution order at runtime
if ((output_pshape.size() != input_rank + 1) || prim->output_pattern.empty())
if (prim->output_pattern.empty())
return false;

// Iteratively check the total product of all static innermost dimensions
// until the crop dimension value matches or the first dynamic dimension is encountered
int64_t mul = 1;
for (size_t i = input_rank - 1; i < output_pshape.size() ; i++) {
if (output_pshape[i].is_dynamic())
return false;
mul *= output_pshape[i].get_length();
for (size_t i = output_pshape.size(); i > 1 ; i--) {
if (output_pshape[i - 1].is_dynamic() || mul == input_last_dim_val)
break;

mul *= output_pshape[i - 1].get_length();
}
if (input_last_dim_val != mul)
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,22 @@ KERNEL(pa_kv_cache_update)(
const uint seq_block_idx = block_indices_begins[seq_idx] + seq_len / PAGED_ATTENTION_BLOCK_SIZE;
const uint block_idx = block_indices[seq_block_idx];

uint key_value_in_offset = seq_idx * KV_HEADS_NUM * HEAD_SIZE + head_idx * HEAD_SIZE;
uint key_in_offset = INPUT0_PAD_BEFORE_FEATURE_NUM +
seq_idx * (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
head_idx * HEAD_SIZE;
uint value_in_offset = INPUT1_PAD_BEFORE_FEATURE_NUM +
seq_idx * (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_BEFORE_FEATURE_NUM + INPUT1_PAD_AFTER_FEATURE_NUM) +
head_idx * HEAD_SIZE;

uint key_out_offset = block_idx * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + current_token_pos_in_block;

uint value_out_offset = block_idx * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + current_token_pos_in_block * HEAD_SIZE;

#define READ_BLOCK_SIZE GENERATE_STAGE_BLOCK_SIZE
for (uint head_idx_index = 0; head_idx_index < HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * READ_BLOCK_SIZE) {
#define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
#define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE)

DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index);
DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE;
Expand All @@ -56,7 +60,7 @@ KERNEL(pa_kv_cache_update)(
#endif
}

input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index);
input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i;
Expand All @@ -83,8 +87,13 @@ KERNEL(pa_kv_cache_update)(

const uint token_start_pos = (past_len + block_start_pos - subsequence_begin_idx) % PAGED_ATTENTION_BLOCK_SIZE;

uint key_value_in_offset = block_start_pos * KV_HEADS_NUM * HEAD_SIZE +
head_idx * HEAD_SIZE;
uint key_in_offset = INPUT0_PAD_BEFORE_FEATURE_NUM +
block_start_pos * (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM) +
head_idx * HEAD_SIZE;

uint value_in_offset = INPUT1_PAD_BEFORE_FEATURE_NUM +
block_start_pos * (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM) +
head_idx * HEAD_SIZE;

const uint current_block_idx = (past_len + block_start_pos - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE;

Expand All @@ -106,14 +115,14 @@ KERNEL(pa_kv_cache_update)(
#define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
#define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE)

DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index);
DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE;
key_cache_data[key_offset] = input_data[i];
}

input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index);
input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i;
Expand All @@ -126,14 +135,14 @@ KERNEL(pa_kv_cache_update)(
#define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
#define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE)

DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index);
DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE;
key_cache_data[key_offset] = input_data[i];
}

input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index);
input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i;
Expand All @@ -146,14 +155,14 @@ KERNEL(pa_kv_cache_update)(
#define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
#define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE)

DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index);
DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE;
key_cache_data[key_offset] = input_data[i];
}

input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index);
input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i;
Expand All @@ -166,22 +175,23 @@ KERNEL(pa_kv_cache_update)(
#define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
#define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE)

DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index);
DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE;
key_cache_data[key_offset] = input_data;
}

input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index);
input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i;
value_cache_data[value_offset] = input_data;
}
}

key_value_in_offset += KV_HEADS_NUM * HEAD_SIZE;
key_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM);
value_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM);
key_out_offset += 1;
value_out_offset += HEAD_SIZE;
}
Expand All @@ -194,22 +204,23 @@ KERNEL(pa_kv_cache_update)(
#define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
#define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE)

DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index);
DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE;
key_cache_data[key_offset] = input_data;
}

input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index);
input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index);

unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) {
uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i;
value_cache_data[value_offset] = input_data;
}
}

key_value_in_offset += KV_HEADS_NUM * HEAD_SIZE;
key_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM);
value_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM);
key_out_offset += 1;
value_out_offset += HEAD_SIZE;
}
Expand Down
30 changes: 11 additions & 19 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,11 @@ KERNEL(rope_ref)(
uint r = rf < HALF_ROTARY_NDIMS ? rf * 2 : 0;
uint f = rf < HEAD_SIZE - ROTARY_NDIMS ? rf * 2 : 0;

#ifdef ENABLE_SLICE
uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, p, b, h * HEAD_SIZE, 0);

input_idx += SLICED_FROM_START * (p * INPUT0_FEATURE_NUM + b + 1)
+ SLICED_FROM_END * (p * INPUT0_FEATURE_NUM + b);
#else
uint input_idx = INPUT0_GET_INDEX(p, b, h * HEAD_SIZE, 0);
#ifdef ENABLE_SLICE
input_idx += SLICED_FROM_START;
#endif

uint cos_sin_p = p < INPUT1_BATCH_NUM ? p : 0;
uint cos_sin_b = b < INPUT1_FEATURE_NUM ? b : 0;
uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_p, cos_sin_b, 0, 0);
Expand Down Expand Up @@ -69,14 +66,11 @@ KERNEL(rope_ref)(
const uint h = (uint)get_global_id(2) / HALF_ROTARY_NDIMS;
const uint r = (uint)get_global_id(2) % HALF_ROTARY_NDIMS;

#ifdef ENABLE_SLICE
uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, b, p, h * HEAD_SIZE, 0);

input_idx += SLICED_FROM_START * (b * INPUT0_FEATURE_NUM + p + 1)
+ SLICED_FROM_END * (b * INPUT0_FEATURE_NUM + p);
#else
uint input_idx = INPUT0_GET_INDEX(b, p, h * HEAD_SIZE, 0);
#ifdef ENABLE_SLICE
input_idx += SLICED_FROM_START;
#endif

uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0;
uint cos_sin_p = p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM < INPUT1_FEATURE_NUM ? p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM : 0;
uint cos_sin_h = h < INPUT1_SIZE_Y ? h : 0;
Expand Down Expand Up @@ -119,15 +113,13 @@ KERNEL(rope_ref)(
const uint p = (uint)get_global_id(2) / HALF_ROTARY_NDIMS;
const uint r = (uint)get_global_id(2) % HALF_ROTARY_NDIMS;

#ifdef ENABLE_SLICE
uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, b, h, p, 0);

input_idx += SLICED_FROM_START * (b * INPUT0_FEATURE_NUM + h + 1)
+ SLICED_FROM_END * (b * INPUT0_FEATURE_NUM + h);
#elif ENABLE_TRANSPOSE
uint input_idx = GET_DATA_INDEX(TRANSPOSED_INPUT0, b, h, p, 0);
#if ENABLE_TRANSPOSE
uint input_idx = INPUT0_GET_INDEX(b, p, h, 0);
#else
uint input_idx = INPUT0_GET_INDEX(b, h, p, 0);
#ifdef ENABLE_SLICE
input_idx += SLICED_FROM_START;
#endif
#endif

uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0;
Expand Down
Loading

0 comments on commit c1e73e9

Please sign in to comment.