summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-11-13 16:55:52 +0200
committerGitHub <noreply@github.com>2023-11-13 16:55:52 +0200
commit3d68f364f15778dc326f5024f2e5af1ad6dfddef (patch)
treec0c11d150ba56b4f646261790728622efa30d8a1 /ggml-cuda.cu
parentc049b37d7baf558944501705b91ac89b26ee3e41 (diff)
ggml : sync (im2col, GPU conv, 32-bit arm compat) (#4060)
ggml-ci
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu106
1 files changed, 104 insertions, 2 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 16340244..7be63925 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4489,6 +4489,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
*dsti = __float2half(*xi);
}
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+ const half * xi = (const half *) cxi;
+ half * dsti = (half *) cdsti;
+
+ *dsti = *xi;
+}
+
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4742,6 +4749,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
+static __global__ void im2col_f32_f16(
+ const float * x, half * dst,
+ int ofs0, int ofs1, int IW, int IH, int CHW,
+ int s0, int s1, int p0, int p1, int d0, int d1) {
+ const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+ const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+ const int offset_dst =
+ (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
+ (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst[offset_dst] = __float2half(0.0f);
+ } else {
+ const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
+ dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+ }
+}
+
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@@ -5642,6 +5668,16 @@ static void ggml_cpy_f32_f16_cuda(
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
+static void ggml_cpy_f16_f16_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -5725,6 +5761,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
}
+static void im2col_f32_f16_cuda(const float * x, half * dst,
+ int OH, int IW, int IH, int OW, int IC,
+ int KH, int KW, int N, int ofs0, int ofs1,
+ int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+ dim3 block_nums(IC, OH, OW);
+ dim3 block_dims(N, KH, KW);
+ im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
// buffer pool for cuda
#define MAX_CUDA_BUFFERS 256
@@ -6522,8 +6567,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
}
- const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-
+ const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
@@ -6698,6 +6742,45 @@ inline void ggml_cuda_op_alibi(
(void) src1_dd;
}
+inline void ggml_cuda_op_im2col(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+ const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+ const int64_t N = src1->ne[is_2D ? 3 : 2];
+ const int64_t IC = src1->ne[is_2D ? 2 : 1];
+ const int64_t IH = is_2D ? src1->ne[1] : 1;
+ const int64_t IW = src1->ne[0];
+
+ const int64_t KH = is_2D ? src0->ne[1] : 1;
+ const int64_t KW = src0->ne[0];
+
+ const int64_t OH = is_2D ? dst->ne[2] : 1;
+ const int64_t OW = dst->ne[1];
+
+ const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+ const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+ im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
+ OH, IW, IH, OW, IC, KH, KW, N,
+ ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+
+ (void) src0;
+ (void) src0_dd;
+}
+
inline void ggml_cuda_op_diag_mask_inf(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7610,6 +7693,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+ ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+ ne10, ne11, nb10, nb11, nb12, main_stream);
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7641,6 +7727,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
}
+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
(void) src0;
(void) src1;
@@ -7934,6 +8024,15 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
return false;
}
+ if (tensor->op == GGML_OP_MUL_MAT) {
+ if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
+#endif
+ return false;
+ }
+ }
+
switch (tensor->op) {
case GGML_OP_REPEAT:
func = ggml_cuda_repeat;
@@ -8012,6 +8111,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_ALIBI:
func = ggml_cuda_alibi;
break;
+ case GGML_OP_IM2COL:
+ func = ggml_cuda_im2col;
+ break;
default:
return false;
}