Skip to content

Commit 58ecf6b

Browse files
authored
cuda: unary ops as float + de-duplicate (#1130)
1 parent ff90529 commit 58ecf6b

File tree

2 files changed

+132
-567
lines changed

2 files changed

+132
-567
lines changed

src/ggml-cuda/clamp.cu

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
11
#include "clamp.cuh"
22

3+
static __device__ __forceinline__ float op_clamp(float x, float min, float max) {
4+
return fminf(fmaxf(x, min), max);
5+
}
6+
37
template <class T>
4-
static __global__ void op_clamp(const T * x, T * dst, const T min, const T max, const int k) {
8+
static __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) {
59
const int i = blockDim.x*blockIdx.x + threadIdx.x;
610

711
if (i >= k) {
812
return;
913
}
1014

11-
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
15+
dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max);
1216
}
1317

1418
template <class T>
1519
static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {
1620
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
17-
op_clamp<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
21+
op_clamp_kernel<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
1822
}
1923

2024

0 commit comments

Comments
 (0)