diff options
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 47 |
1 files changed, 45 insertions, 2 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7e92c519..654d3632 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -415,6 +415,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 +#define CUDA_CLAMP_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 #define CUDA_ALIBI_BLOCK_SIZE 32 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 @@ -4585,6 +4586,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale dst[i] = scale * x[i]; } +static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); +} template<int qk, int qr, dequantize_kernel_t dq> static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { @@ -5475,6 +5485,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k); } +static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE; + clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k); +} + template<typename T> static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { @@ -6419,12 +6434,12 @@ inline void ggml_cuda_op_alibi( const int64_t ne02 = src0->ne[2]; const int64_t nrows = ggml_nrows(src0); - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - GGML_ASSERT(ne01 + n_past == ne00); + //GGML_ASSERT(ne01 + n_past == ne00); GGML_ASSERT(n_head == ne02); const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); @@ -6500,6 +6515,24 @@ inline void ggml_cuda_op_scale( (void) src1_dd; } +inline void ggml_cuda_op_clamp( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const float min = ((float *) dst->op_params)[0]; + const float max = ((float *) dst->op_params)[1]; + + clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream); + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + (void) src1_dd; +} + static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { const int64_t nrows0 = ggml_nrows(src0); @@ -7061,6 +7094,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } +static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp); +} + static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -7470,6 +7507,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_SCALE: func = ggml_cuda_scale; break; + case GGML_OP_CLAMP: + if (!any_on_device) { + return false; + } + func = ggml_cuda_clamp; + break; case GGML_OP_CPY: func = ggml_cuda_cpy; break; |