Diffstat (limited to 'ggml-cuda.cu')
 ggml-cuda.cu | 186 ++++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 169 insertions(+), 17 deletions(-)
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 9019a849..1200d1c8 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -7,6 +7,7 @@
#include <stdio.h>
#include <atomic>
#include <assert.h>
+#include <float.h> // FLT_MAX, used by the new f32 -> q4_1 copy kernel
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
@@ -4559,6 +4560,116 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
cpy_1(cx + x_offset, cdst + dst_offset);
}
+// copy/quantize one block of QK8_0 floats: one scale d plus QK8_0 int8 quants
+static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q8_0 * dsti = (block_q8_0 *) cdsti;
+
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float v = xi[j];
+ amax = fmaxf(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1); // scale: amax maps to 127
+ const float id = d ? 1.0f/d : 0.0f; // inverse scale; 0 for an all-zero block
+
+ dsti->d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = xi[j]*id;
+
+ dsti->qs[j] = roundf(x0);
+ }
+}
+
+// copy/quantize one block of QK4_0 floats: one scale d plus QK4_0/2 bytes of packed nibbles
+static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q4_0 * dsti = (block_q4_0 *) cdsti;
+
+ float amax = 0.0f; // absolute max
+ float vmax = 0.0f; // signed value at the absolute max
+
+ for (int j = 0; j < QK4_0; ++j) {
+ const float v = xi[j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -8; // scale: the extreme value maps to -8
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->d = d;
+
+ for (int j = 0; j < QK4_0/2; ++j) {
+ const float x0 = xi[0 + j]*id;
+ const float x1 = xi[QK4_0/2 + j]*id;
+
+ // shift [-8, 7] into [0, 15] and clamp; +8.5f rounds to nearest
+ const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+ dsti->qs[j] = xi0; // low nibble: first half of the block
+ dsti->qs[j] |= xi1 << 4; // high nibble: second half of the block
+ }
+}
+
+// copy/quantize one block of QK4_1 floats: scale d and minimum, plus packed nibbles
+static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q4_1 * dsti = (block_q4_1 *) cdsti;
+
+ float vmin = FLT_MAX;
+ float vmax = -FLT_MAX;
+
+ for (int j = 0; j < QK4_1; ++j) {
+ const float v = xi[j];
+
+ if (v < vmin) vmin = v;
+ if (v > vmax) vmax = v;
+ }
+
+ const float d = (vmax - vmin) / ((1 << 4) - 1); // asymmetric: [vmin, vmax] maps to [0, 15]
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->dm.x = d; // scale
+ dsti->dm.y = vmin; // offset
+
+ for (int j = 0; j < QK4_1/2; ++j) {
+ const float x0 = (xi[0 + j] - vmin)*id;
+ const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
+
+ const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+ dsti->qs[j] = xi0;
+ dsti->qs[j] |= xi1 << 4;
+ }
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+ // each thread handles one block of qk source elements
+ const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+ if (i >= ne) {
+ return;
+ }
+
+ // source offset is per element (f32 input)
+ const int i02 = i / (ne00*ne01);
+ const int i01 = (i - i02*ne01*ne00) / ne00;
+ const int i00 = (i - i02*ne01*ne00 - i01*ne00);
+ const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+
+ // destination offset is per block: dim 0 advances by one quant block per qk elements
+ const int i12 = i / (ne10*ne11);
+ const int i11 = (i - i12*ne10*ne11) / ne10;
+ const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
+ const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+
+ cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
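For reference, the Q8_0 scheme implemented by cpy_blck_f32_q8_0 above can be mirrored on the host to sanity-check the kernel. A minimal sketch, assuming QK8_0 == 32 and a block of one scale plus 32 int8 quants (the real block_q8_0 stores the scale as fp16; the ref_* names are illustrative, not part of ggml):

#include <cmath>
#include <cstdint>

struct ref_block_q8_0 {
    float  d;      // scale (fp16 in ggml; plain float here for simplicity)
    int8_t qs[32]; // quantized values
};

static void ref_quantize_q8_0(const float * x, ref_block_q8_0 * b) {
    float amax = 0.0f;
    for (int j = 0; j < 32; j++) {
        amax = fmaxf(amax, fabsf(x[j]));
    }
    const float d  = amax/127.0f;   // same scale as the kernel: (1 << 7) - 1
    const float id = d ? 1.0f/d : 0.0f;
    b->d = d;
    for (int j = 0; j < 32; j++) {
        b->qs[j] = (int8_t) roundf(x[j]*id);
    }
}

static float ref_dequantize_q8_0(const ref_block_q8_0 * b, int j) {
    return b->d * b->qs[j]; // per-element round-trip error is at most d/2
}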
@@ -5737,6 +5848,39 @@ static void ggml_cpy_f32_f16_cuda(
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
+static void ggml_cpy_f32_q8_0_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK8_0 == 0);
+ const int num_blocks = ne / QK8_0;
+ // one single-thread CUDA block per quantization block of QK8_0 elements
+ cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_q4_0_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK4_0 == 0);
+ const int num_blocks = ne / QK4_0;
+ cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_q4_1_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK4_1 == 0);
+ const int num_blocks = ne / QK4_1;
+ cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
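The 4-bit packing used by cpy_blck_f32_q4_0 (two quants per byte, low nibble first) can likewise be mirrored on the host. A sketch under the assumption QK4_0 == 32, taking the scale d as the kernel computes it; the ref_* names are illustrative:

#include <cmath>
#include <cstdint>

// Pack two 4-bit quants per byte, low nibble first, as in cpy_blck_f32_q4_0.
static void ref_pack_q4_0(const float * x, float d, uint8_t qs[16]) {
    const float id = d ? 1.0f/d : 0.0f;
    for (int j = 0; j < 16; j++) {
        // x*id lies in [-8, 8]; +8.5f shifts it so truncation rounds to
        // nearest, then the clamp keeps the result in [0, 15]
        const int q0 = (int) fminf(15.0f, x[j     ]*id + 8.5f);
        const int q1 = (int) fminf(15.0f, x[16 + j]*id + 8.5f);
        qs[j] = (uint8_t) (q0 | (q1 << 4));
    }
}

static float ref_unpack_q4_0(const uint8_t qs[16], float d, int j) {
    const int q = j < 16 ? (qs[j] & 0x0F) : (qs[j - 16] >> 4);
    return d*(q - 8); // recenter [0, 15] back to [-8, 7]
}

Q4_1 is analogous, except it stores the pair (d, vmin) and quantizes (x - vmin)*id with a plain +0.5f rounding offset, since that range is already non-negative.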
@@ -6093,20 +6237,21 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
const enum ggml_type type = src->type;
const int64_t ts = ggml_type_size(type);
const int64_t bs = ggml_blck_size(type);
- int64_t i1_diff = i1_high - i1_low;
+ const int64_t i1_diff = i1_high - i1_low;
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
- if (nb0 == ts && nb1 == ts*ne0/bs) {
+ if (nb0 == ts && nb1 == ts*(ne0/bs)) {
return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
}
if (nb0 == ts) {
- return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
+ return cudaMemcpy2DAsync(dst_ptr, ts*(ne0/bs), x, nb1, ts*(ne0/bs), i1_diff, kind, stream);
}
+ GGML_ASSERT(bs == 1 && "TODO: implement bs != 1");
for (int64_t i1 = 0; i1 < i1_diff; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
- void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+ void * rd = (void *) (dst_ptr + i1*ts*ne0);
// pretend the row is a matrix with cols=1
- cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
+ cudaError_t r = cudaMemcpy2DAsync(rd, ts, rx, nb0, ts, ne0, kind, stream);
if (r != cudaSuccess) { return r; }
}
return cudaSuccess;
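The parenthesization changes in ggml_cuda_cpy_tensor_2d make the byte arithmetic explicit for block-quantized types, and the new assert restricts the strided per-element fallback to types with block size 1. A worked example of the row-size arithmetic, assuming Q8_0's 34-byte block of 32 elements:

// Worked example for the contiguous-row fast path above (Q8_0 assumed):
const int64_t ts  = 34;                // ggml_type_size: fp16 scale + 32 int8 quants
const int64_t bs  = 32;                // ggml_blck_size: elements per block
const int64_t ne0 = 4096;              // row length in elements
const int64_t row_bytes = ts*(ne0/bs); // 34*128 = 4352 bytes per row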
@@ -6474,6 +6619,8 @@ inline void ggml_cuda_op_mul_mat_vec_q(
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, const cudaStream_t & stream) {
+ GGML_ASSERT(ggml_nrows(src1) == 1);
+
const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;
@@ -6533,7 +6680,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
size_t ash;
dfloat * src1_dfloat = nullptr; // dfloat == half
- bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+ bool src1_convert_f16 =
+ src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
@@ -7103,10 +7251,9 @@ static void ggml_cuda_op_mul_mat(
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool src0_is_contiguous = ggml_is_contiguous(src0);
-
const bool src1_is_contiguous = ggml_is_contiguous(src1);
- const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ?
- ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
+
+ const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
GGML_ASSERT(!(split && ne02 > 1));
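GGML_PAD replaces the open-coded ternary with the same round-up-to-multiple computation, which reads much better. Its effect, assuming MATRIX_ROW_PADDING is 512 as used elsewhere in this file:

// GGML_PAD(x, n) rounds x up to the next multiple of n:
//   GGML_PAD(4096, 512) == 4096  // already a multiple, unchanged
//   GGML_PAD(4097, 512) == 4608  // rounded up to the next multiple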
@@ -7231,7 +7378,7 @@ static void ggml_cuda_op_mul_mat(
const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
// for split tensors the data begins at i0 == i0_offset_low
- char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs;
+ char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset;
float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
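Same idea for src0_dd_i: the added parentheses compute the per-slice byte size first, where the product is an exact multiple of the block size, instead of multiplying by the slice index before dividing. With illustrative Q8_0 numbers:

// slice_bytes is exact because rows are padded to a block multiple:
//   slice_bytes = (ne01*ne00*src0_ts)/src0_bs   // e.g. (ne01*4096*34)/32
//   src0_dd_i   = src0_dd[id] + (i0/i02_divisor)*slice_bytes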
@@ -7698,10 +7845,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
#ifdef GGML_CUDA_FORCE_DMMV
const bool use_mul_mat_vec_q = false;
#else
- const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+ const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
#endif // GGML_CUDA_FORCE_DMMV
if (use_mul_mat_vec_q) {
+ // NOTE: this kernel does not support ggml_nrows(src1) > 1
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
} else {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
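Together with the GGML_ASSERT(ggml_nrows(src1) == 1) added to ggml_cuda_op_mul_mat_vec_q above, this narrows the mul_mat_vec_q path to true matrix-vector products. Restated as a sketch (MIN_CC_DP4A is assumed to be 610, the first compute capability with the dp4a instruction):

// Illustrative restatement of the dispatch condition:
static bool pick_mmvq(int min_cc, bool src0_quantized, int64_t src1_nrows) {
    return min_cc >= 610 /* MIN_CC_DP4A, assumed */
        && src0_quantized
        && src1_nrows == 1; // batched src1 falls back to the dequantize path
}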
@@ -7770,14 +7918,17 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
- ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
- ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
- ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
- ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+ ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+ ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+ ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
- ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
- ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7788,6 +7939,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
}
static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ // TODO: why do we pass dst as src1 here?
ggml_cuda_cpy(src0, dst, nullptr);
(void) src1;
}
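The practical effect of the new dispatch cases is that ggml_cpy from F32 into a quantized tensor now runs on the GPU, e.g. for writing a quantized KV cache without a host round trip. A usage sketch with the public ggml API (context setup elided; the 4096 size is illustrative and satisfies the ne % QK8_0 == 0 assert):

// Build a graph node that copies f32 data into a Q8_0 tensor; on the CUDA
// backend this now dispatches to ggml_cpy_f32_q8_0_cuda.
struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  4096);
struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_Q8_0, 4096);
struct ggml_tensor * cpy = ggml_cpy(ctx, src, dst);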