Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sycl-exp : unify rope neox/norm #7919

Merged
merged 4 commits on Jun 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 94 additions & 122 deletions ggml-sycl.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -8826,7 +8826,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
}

struct rope_corr_dims {
float v[4];
float v[2];
};

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
Expand All @@ -8850,29 +8850,38 @@ static void rope_yarn(
}

// rope == RoPE == rotary positional embedding
template<typename T, bool has_pos>
static void rope(
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
float ext_factor, float attn_factor, rope_corr_dims corr_dims
,
template<typename T, bool has_ff>
static void rope_norm(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
const sycl::nd_item<3> &item_ct1) {
const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));

if (col >= ncols) {
if (i0 >= ne0) {
return;
}

const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int i = row*ncols + col;

if (i0 >= n_dims) {
const int i = row*ne0 + i0;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ne0 + i0;
const int i2 = row/p_delta_rows;

const int p = has_pos ? pos[i2] : 0;
const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols);
const float theta_base = pos[i2]*sycl::pow(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;

float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + 1];
Expand All @@ -8881,45 +8890,40 @@ static void rope(
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}

template<typename T, bool has_pos, bool has_freq_facs>
static void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
template <typename T, bool has_ff>
static void rope_neox(const T *x, T *dst, int ne0, int n_dims,
const int32_t *pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor,
rope_corr_dims corr_dims, float theta_scale,
const float *freq_factors,
const sycl::nd_item<3> &item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));

if (col >= ncols) {
if (i0 >= ne0) {
return;
}

const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int ib = col / n_dims;
const int ic = col % n_dims;

if (ib > 0) {
const int i = row*ncols + ib*n_dims + ic;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ncols + ib*n_dims + ic/2;
const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows;

float cur_rot = inv_ndims * ic - ib;

const int p = has_pos ? pos[i2] : 0;
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;

const float theta_base =
p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
const float theta_base = pos[i2]*sycl::pow(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
Expand Down Expand Up @@ -12375,15 +12379,18 @@ static void clamp_f32_sycl(const float *x, float *dst, const float min,
}

template <typename T>
static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
static void rope_norm_sycl(const T *x, T *dst, int ne0, int n_dims, int nr,
const int32_t *pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor,
rope_corr_dims corr_dims, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % 2 == 0);
rope_corr_dims corr_dims, const float * freq_factors, dpct::queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, num_blocks_x, nrows);
if (pos == nullptr) {
const int n_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, n_blocks_x, nr);

const float theta_scale = sycl::pow(freq_base, -2.0f/n_dims);

if (freq_factors == nullptr) {
/*
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
Expand All @@ -12395,8 +12402,8 @@ static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope<T, false>(x, dst, ncols, pos, freq_scale, p_delta_rows,
freq_base, ext_factor, attn_factor, corr_dims,
rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
item_ct1);
});
} else {
Expand All @@ -12411,70 +12418,46 @@ static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope<T, true>(x, dst, ncols, pos, freq_scale, p_delta_rows,
freq_base, ext_factor, attn_factor, corr_dims,
rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
item_ct1);
});
}
}

template <typename T>
static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
static void rope_neox_sycl(const T *x, T *dst, int ne0, int n_dims, int nr,
const int32_t *pos, float freq_scale,
int p_delta_rows, float freq_base, float ext_factor,
float attn_factor, rope_corr_dims corr_dims,
const float * freq_factors, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, num_blocks_x, nrows);
const int n_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, n_blocks_x, nr);

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.0f / n_dims;
const float theta_scale = sycl::pow(freq_base, -2.0f/n_dims);

if (pos == nullptr) {
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
if (freq_factors == nullptr) {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, inv_ndims, freq_factors,
item_ct1);
});
} else {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, inv_ndims, freq_factors,
item_ct1);
});
}
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
if (freq_factors == nullptr) {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors,
item_ct1);
});
} else {
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});

if (freq_factors == nullptr) {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
});
} else {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
});
}
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors,
item_ct1);
});
}
}

Expand Down Expand Up @@ -12592,8 +12575,8 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const float m0 = sycl::pow(2.0f, -(max_bias ) / n_head_log2);
const float m1 = sycl::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);

const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
if (n_local_scratch*sizeof(float) < local_mem_size) {
Expand Down Expand Up @@ -14005,8 +13988,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);
const int64_t nr = ggml_nrows(src0);

//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
Expand All @@ -14023,27 +14005,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

const float * freq_factors = nullptr;
const int32_t * pos = nullptr;
if ((mode & 1) == 0) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
pos = (const int32_t *) src1_dd;
}

const bool is_neox = mode & 2;

#pragma message("TODO: update rope NORM mode to match NEOX mode")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")

if (is_neox) {
pos = (const int32_t *) src1_dd;
const int32_t * pos = (const int32_t *) src1_dd;

if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
}
} else {
GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
const float * freq_factors = nullptr;
if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
}

rope_corr_dims corr_dims;
Expand All @@ -14053,27 +14021,27 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_sycl(
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
ne00, n_dims, nrows, pos, freq_scale, ne01,
ne00, n_dims, nr, pos, freq_scale, ne01,
freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, main_stream);
} else {
GGML_ASSERT(false);
}
} else {
if (src0->type == GGML_TYPE_F32) {
rope_sycl(
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream
rope_norm_sycl(
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream);
rope_norm_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream);
} else {
GGML_ASSERT(false);
}
Expand Down Expand Up @@ -17267,7 +17235,12 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
case GGML_OP_CONCAT:
{
ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
int dim = op->op_params[0];
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
} break;
case GGML_OP_ROPE:
{
return ggml_is_contiguous(op->src[0]);
} break;
case GGML_OP_DUP:
case GGML_OP_NONE:
Expand All @@ -17287,7 +17260,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
joeatodd marked this conversation as resolved.
Show resolved Hide resolved
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
Expand Down
Loading