Skip to content

Commit

Permalink
Add support for sqrt on CUDA (#7953)
Browse files Browse the repository at this point in the history
* cuda sqrt support

* enable cuda in pca

* fix comments in pca

* add test

* add sqrt to ggml_backend_cuda_supports_op

* fix test

* new line

* Use F32 sqrtf instead of F64 sqrt

Co-authored-by: Johannes Gäßler <[email protected]>

---------

Co-authored-by: Johannes Gäßler <[email protected]>
  • Loading branch information
calvin-laurenson and JohannesGaessler authored Jun 16, 2024
1 parent 19b7a83 commit 43b35e3
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 8 deletions.
16 changes: 8 additions & 8 deletions examples/cvector-generator/pca.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ struct pca_model {
struct ggml_tensor * dev_eigenvector;

pca_model(struct ggml_tensor * t_input) {
// TODO: enable GPU support when support for GGML_OP_SQRT is added
// #ifdef GGML_USE_CUDA
// fprintf(stderr, "%s: using CUDA backend\n", __func__);
// backend = ggml_backend_cuda_init(0); // init device 0
// if (!backend) {
// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
// }
// #endif
#ifdef GGML_USE_CUDA
fprintf(stderr, "%s: using CUDA backend\n", __func__);
backend = ggml_backend_cuda_init(0); // init device 0
if (!backend) {
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
}
#endif

// TODO: enable Metal support when support for GGML_OP_SQRT is added
// #ifdef GGML_USE_METAL
// fprintf(stderr, "%s: using Metal backend\n", __func__);
// backend = ggml_backend_metal_init();
Expand Down
4 changes: 4 additions & 0 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2267,6 +2267,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SQR:
ggml_cuda_op_sqr(ctx, dst);
break;
case GGML_OP_SQRT:
ggml_cuda_op_sqrt(ctx, dst);
break;
case GGML_OP_CLAMP:
ggml_cuda_op_clamp(ctx, dst);
break;
Expand Down Expand Up @@ -2830,6 +2833,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_CLAMP:
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
Expand Down
28 changes: 28 additions & 0 deletions ggml-cuda/unary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] * x[i];
}

static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}
dst[i] = sqrtf(x[i]);
}

static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
Expand Down Expand Up @@ -142,6 +151,11 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
Expand Down Expand Up @@ -284,3 +298,17 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
3 changes: 3 additions & 0 deletions ggml-cuda/unary.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256

void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

Expand All @@ -28,3 +29,5 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
28 changes: 28 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,33 @@ struct test_sqr : public test_case {
}
};

// GGML_OP_SQRT
struct test_sqrt : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

test_sqrt(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 10})
: type(type), ne(ne) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * out = ggml_sqrt(ctx, a);
return out;
}

void initialize_tensors(ggml_context * ctx) override {
// fill with positive values
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, 0.0f, 100.0f);
}
}
};

// GGML_OP_CLAMP
struct test_clamp : public test_case {
const ggml_type type;
Expand Down Expand Up @@ -2200,6 +2227,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}

test_cases.emplace_back(new test_sqr());
test_cases.emplace_back(new test_sqrt());
test_cases.emplace_back(new test_clamp());

test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
Expand Down

0 comments on commit 43b35e3

Please sign in to comment.