Skip to content

Commit 8e4fc6d

Browse files
tamarPaltamarPal
authored andcommitted
sycl: add ROLL operation support (ggml-org#16665)
* sycl: add ROLL operation support - Implement ggml_sycl_roll function for F32 tensors - Add multi-axis roll operation with SYCL kernel - Support all 4 tensor dimensions with proper shift normalization - Add roll.cpp and roll.hpp to SYCL backend - Update backend dispatch and supports_op for GGML_OP_ROLL - Tests: 17662/17662 pass with identical CPU reference results * fix: remove trailing whitespace from roll.cpp - Fix EditorConfig violations in ggml/src/ggml-sycl/roll.cpp - Remove trailing spaces from lines 6, 11, 28, 47, 58, 60 * ci: retrigger * sycl: remove wait() calls from ROLL operation * fix: editorconfig — LF endings + final newline for roll.hpp --------- Co-authored-by: tamarPal <[email protected]>
1 parent c6c8195 commit 8e4fc6d

File tree

4 files changed

+148
-0
lines changed

4 files changed

+148
-0
lines changed

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "pad.hpp"
3333
#include "quantize.hpp"
3434
#include "quants.hpp"
35+
#include "roll.hpp"
3536
#include "rope.hpp"
3637
#include "set_rows.hpp"
3738
#include "softmax.hpp"

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3921,6 +3921,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
39213921
case GGML_OP_GATED_LINEAR_ATTN:
39223922
ggml_sycl_op_gated_linear_attn(ctx, dst);
39233923
break;
3924+
case GGML_OP_ROLL:
3925+
ggml_sycl_roll(ctx, dst);
3926+
break;
39243927
case GGML_OP_ARANGE:
39253928
ggml_sycl_arange(ctx, dst);
39263929
break;
@@ -4599,6 +4602,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
45994602
case GGML_OP_RWKV_WKV7:
46004603
case GGML_OP_GATED_LINEAR_ATTN:
46014604
return true;
4605+
case GGML_OP_ROLL:
4606+
return op->type == GGML_TYPE_F32;
46024607
case GGML_OP_ARANGE:
46034608
return op->type == GGML_TYPE_F32;
46044609
default:

ggml/src/ggml-sycl/roll.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#include "roll.hpp"
2+
#include "common.hpp"
3+
4+
using namespace sycl;
5+
6+
static inline int wrap_add(int i, int shift, int n) {
7+
8+
int s = i + shift;
9+
return (s >= n) ? (s - n) : s;
10+
}
11+
12+
static void kernel_roll_fused_i0_i1(
13+
queue &q,
14+
const float *src_d,
15+
float *dst_d,
16+
int ne0, int ne1, int ne2, int ne3,
17+
int sh0, int sh1, int sh2, int sh3)
18+
{
19+
if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;
20+
21+
22+
const int stride1 = ne0;
23+
const int stride2 = ne0 * ne1;
24+
const int stride3 = ne0 * ne1 * ne2;
25+
26+
27+
const int shNe0 = (ne0 - sh0) % ne0;
28+
const int shNe1 = (ne1 - sh1) % ne1;
29+
const int shNe2 = (ne2 - sh2) % ne2;
30+
const int shNe3 = (ne3 - sh3) % ne3;
31+
32+
33+
const size_t g0 = (size_t) ne3;
34+
const size_t g1 = (size_t) ne2;
35+
const size_t g2 = (size_t) (ne1 * ne0);
36+
37+
const range<3> global{ g0, g1, g2 };
38+
39+
q.submit([&](handler &h) {
40+
h.parallel_for(global, [=](id<3> idx) {
41+
const int i3 = (int) idx[0];
42+
const int i2 = (int) idx[1];
43+
44+
const int fused = (int) idx[2];
45+
const int i1 = fused / ne0;
46+
const int i0 = fused - i1 * ne0; // fused % ne0
47+
48+
49+
const int idx_dst = i0
50+
+ i1 * stride1
51+
+ i2 * stride2
52+
+ i3 * stride3;
53+
54+
55+
const int s0 = wrap_add(i0, shNe0, ne0);
56+
const int s1 = wrap_add(i1, shNe1, ne1);
57+
const int s2 = wrap_add(i2, shNe2, ne2);
58+
const int s3 = wrap_add(i3, shNe3, ne3);
59+
60+
const int idx_src = s0
61+
+ s1 * stride1
62+
+ s2 * stride2
63+
+ s3 * stride3;
64+
65+
dst_d[idx_dst] = src_d[idx_src];
66+
});
67+
});
68+
}
69+
70+
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
71+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
72+
73+
const ggml_tensor *src = dst->src[0];
74+
GGML_ASSERT(src && src->type == GGML_TYPE_F32);
75+
76+
const int ne0 = (int) dst->ne[0];
77+
const int ne1 = (int) dst->ne[1];
78+
const int ne2 = (int) dst->ne[2];
79+
const int ne3 = (int) dst->ne[3];
80+
81+
const int32_t *params = (const int32_t *) dst->op_params;
82+
int shift0 = params[0];
83+
int shift1 = params[1];
84+
int shift2 = params[2];
85+
int shift3 = params[3];
86+
87+
88+
if ((shift0 | shift1 | shift2 | shift3) == 0) {
89+
const size_t nb = ggml_nbytes(src);
90+
queue *q = ctx.stream();
91+
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
92+
return;
93+
}
94+
95+
auto norm = [](int sh, int n) -> int {
96+
if (n <= 0) return 0;
97+
sh %= n;
98+
if (sh < 0) sh += n;
99+
return sh;
100+
};
101+
shift0 = norm(shift0, ne0);
102+
shift1 = norm(shift1, ne1);
103+
shift2 = norm(shift2, ne2);
104+
shift3 = norm(shift3, ne3);
105+
106+
try {
107+
queue *q = ctx.stream();
108+
109+
const float *src_d = (const float *) src->data;
110+
float *dst_d = (float *) dst->data;
111+
GGML_ASSERT(src_d && dst_d);
112+
113+
kernel_roll_fused_i0_i1(
114+
*q, src_d, dst_d,
115+
ne0, ne1, ne2, ne3,
116+
shift0, shift1, shift2, shift3
117+
);
118+
} catch (const std::exception &e) {
119+
std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
120+
throw;
121+
}
122+
}

ggml/src/ggml-sycl/roll.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//
2+
// MIT license
3+
// Copyright (C) 2024 Intel Corporation
4+
// SPDX-License-Identifier: MIT
5+
//
6+
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
13+
#ifndef GGML_SYCL_ROLL_HPP
14+
#define GGML_SYCL_ROLL_HPP
15+
16+
#include "common.hpp"
17+
18+
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
19+
20+
#endif // GGML_SYCL_ROLL_HPP

0 commit comments

Comments
 (0)