|  | 
|  | 1 | +#include "roll.hpp" | 
|  | 2 | +#include "common.hpp" | 
|  | 3 | + | 
|  | 4 | +using namespace sycl; | 
|  | 5 | + | 
// Add `shift` to index `i` modulo `n`.
// Precondition: 0 <= i < n and 0 <= shift < n, so the sum lies in [0, 2n)
// and a single conditional subtraction replaces a full modulo.
static inline int wrap_add(int i, int shift, int n) {
    const int sum = i + shift;
    if (sum < n) {
        return sum;
    }
    return sum - n;
}
|  | 11 | + | 
|  | 12 | +static void kernel_roll_fused_i0_i1( | 
|  | 13 | +    queue &q, | 
|  | 14 | +    const float *src_d, | 
|  | 15 | +    float *dst_d, | 
|  | 16 | +    int ne0, int ne1, int ne2, int ne3, | 
|  | 17 | +    int sh0, int sh1, int sh2, int sh3) | 
|  | 18 | +{ | 
|  | 19 | +    if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return; | 
|  | 20 | + | 
|  | 21 | + | 
|  | 22 | +    const int stride1 = ne0; | 
|  | 23 | +    const int stride2 = ne0 * ne1; | 
|  | 24 | +    const int stride3 = ne0 * ne1 * ne2; | 
|  | 25 | + | 
|  | 26 | + | 
|  | 27 | +    const int shNe0 = (ne0 - sh0) % ne0; | 
|  | 28 | +    const int shNe1 = (ne1 - sh1) % ne1; | 
|  | 29 | +    const int shNe2 = (ne2 - sh2) % ne2; | 
|  | 30 | +    const int shNe3 = (ne3 - sh3) % ne3; | 
|  | 31 | + | 
|  | 32 | + | 
|  | 33 | +    const size_t g0 = (size_t) ne3; | 
|  | 34 | +    const size_t g1 = (size_t) ne2; | 
|  | 35 | +    const size_t g2 = (size_t) (ne1 * ne0); | 
|  | 36 | + | 
|  | 37 | +    const range<3> global{ g0, g1, g2 }; | 
|  | 38 | + | 
|  | 39 | +    q.submit([&](handler &h) { | 
|  | 40 | +        h.parallel_for(global, [=](id<3> idx) { | 
|  | 41 | +            const int i3 = (int) idx[0]; | 
|  | 42 | +            const int i2 = (int) idx[1]; | 
|  | 43 | + | 
|  | 44 | +            const int fused = (int) idx[2]; | 
|  | 45 | +            const int i1 = fused / ne0; | 
|  | 46 | +            const int i0 = fused - i1 * ne0;  // fused % ne0 | 
|  | 47 | + | 
|  | 48 | + | 
|  | 49 | +            const int idx_dst = i0 | 
|  | 50 | +                              + i1 * stride1 | 
|  | 51 | +                              + i2 * stride2 | 
|  | 52 | +                              + i3 * stride3; | 
|  | 53 | + | 
|  | 54 | + | 
|  | 55 | +            const int s0 = wrap_add(i0, shNe0, ne0); | 
|  | 56 | +            const int s1 = wrap_add(i1, shNe1, ne1); | 
|  | 57 | +            const int s2 = wrap_add(i2, shNe2, ne2); | 
|  | 58 | +            const int s3 = wrap_add(i3, shNe3, ne3); | 
|  | 59 | + | 
|  | 60 | +            const int idx_src = s0 | 
|  | 61 | +                              + s1 * stride1 | 
|  | 62 | +                              + s2 * stride2 | 
|  | 63 | +                              + s3 * stride3; | 
|  | 64 | + | 
|  | 65 | +            dst_d[idx_dst] = src_d[idx_src]; | 
|  | 66 | +        }); | 
|  | 67 | +    }); | 
|  | 68 | +} | 
|  | 69 | + | 
|  | 70 | +void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { | 
|  | 71 | +    GGML_ASSERT(dst->type == GGML_TYPE_F32); | 
|  | 72 | + | 
|  | 73 | +    const ggml_tensor *src = dst->src[0]; | 
|  | 74 | +    GGML_ASSERT(src && src->type == GGML_TYPE_F32); | 
|  | 75 | + | 
|  | 76 | +    const int ne0 = (int) dst->ne[0]; | 
|  | 77 | +    const int ne1 = (int) dst->ne[1]; | 
|  | 78 | +    const int ne2 = (int) dst->ne[2]; | 
|  | 79 | +    const int ne3 = (int) dst->ne[3]; | 
|  | 80 | + | 
|  | 81 | +    const int32_t *params = (const int32_t *) dst->op_params; | 
|  | 82 | +    int shift0 = params[0]; | 
|  | 83 | +    int shift1 = params[1]; | 
|  | 84 | +    int shift2 = params[2]; | 
|  | 85 | +    int shift3 = params[3]; | 
|  | 86 | + | 
|  | 87 | + | 
|  | 88 | +    if ((shift0 | shift1 | shift2 | shift3) == 0) { | 
|  | 89 | +        const size_t nb = ggml_nbytes(src); | 
|  | 90 | +        queue *q = ctx.stream(); | 
|  | 91 | +        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb))); | 
|  | 92 | +        return; | 
|  | 93 | +    } | 
|  | 94 | + | 
|  | 95 | +    auto norm = [](int sh, int n) -> int { | 
|  | 96 | +        if (n <= 0) return 0; | 
|  | 97 | +        sh %= n; | 
|  | 98 | +        if (sh < 0) sh += n; | 
|  | 99 | +        return sh; | 
|  | 100 | +    }; | 
|  | 101 | +    shift0 = norm(shift0, ne0); | 
|  | 102 | +    shift1 = norm(shift1, ne1); | 
|  | 103 | +    shift2 = norm(shift2, ne2); | 
|  | 104 | +    shift3 = norm(shift3, ne3); | 
|  | 105 | + | 
|  | 106 | +    try { | 
|  | 107 | +        queue *q = ctx.stream(); | 
|  | 108 | + | 
|  | 109 | +        const float *src_d = (const float *) src->data; | 
|  | 110 | +        float *dst_d = (float *) dst->data; | 
|  | 111 | +        GGML_ASSERT(src_d && dst_d); | 
|  | 112 | + | 
|  | 113 | +        kernel_roll_fused_i0_i1( | 
|  | 114 | +            *q, src_d, dst_d, | 
|  | 115 | +            ne0, ne1, ne2, ne3, | 
|  | 116 | +            shift0, shift1, shift2, shift3 | 
|  | 117 | +        ); | 
|  | 118 | +    } catch (const std::exception &e) { | 
|  | 119 | +        std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what()); | 
|  | 120 | +        throw; | 
|  | 121 | +    } | 
|  | 122 | +} | 
0 commit comments