summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--ggml/src/ggml-cuda/common.cuh7
-rw-r--r--ggml/src/ggml-cuda/convert.cu42
-rw-r--r--ggml/src/ggml-cuda/cpy.cu50
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu12
-rw-r--r--ggml/src/ggml-cuda/vecdotq.cuh44
5 files changed, 155 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index d75b219b..d7e9c529 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -376,6 +376,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q6_0> {
+ static constexpr int qk = QK6_0;
+ static constexpr int qr = QR6_0;
+ static constexpr int qi = QI6_0;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
static constexpr int qk = QK8_0;
static constexpr int qr = QR8_0;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index c74b030b..7089a6df 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -129,6 +129,36 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
}
}
+template<typename dst_t>
+static __global__ void dequantize_block_q6_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+ const int64_t i = blockIdx.x;
+
+ // assume 32 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8;
+ const int64_t ir = tid%8;
+ const int64_t ib = 8*i + ir;
+ if (ib >= nb32) {
+ return;
+ }
+
+ dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+ const block_q6_0 * x = (const block_q6_0 *)vx + ib;
+ const float d = __half2float(x->d);
+ const float dm = -32*d;
+
+ const uint8_t * qs = x->qs + 4*il;
+ const uint8_t * qh = x->qh + 4*(il%2);
+
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t h = qh[l] >> 4*(il/2);
+ y[l+ 0] = d * ((qs[l] & 0xF) | ((h << 4) & 0x30)) + dm;
+ y[l+16] = d * ((qs[l] >> 4) | ((h << 2) & 0x30)) + dm;
+ }
+}
+
//================================== k-quants
template<typename dst_t>
@@ -768,6 +798,14 @@ static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t n
}
template<typename dst_t>
+static void dequantize_row_q6_0_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int nb32 = k / 32;
+ const int nb = (k + 255) / 256;
+ dequantize_block_q6_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb32 = k / 32;
@@ -1004,6 +1042,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q6_0:
+ return dequantize_row_q6_0_cuda;
case GGML_TYPE_Q8_0:
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
return dequantize_block_q8_0_f16_cuda;
@@ -1074,6 +1114,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q6_0:
+ return dequantize_row_q6_0_cuda;
case GGML_TYPE_Q8_0:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 1a84a4cb..0b269a86 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -221,6 +221,41 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
memcpy(dsti->qh, &qh, sizeof(qh));
}
+static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q6_0 * dsti = (block_q6_0 *) cdsti;
+
+ float amax = 0.0f;
+ float vmax = 0.0f;
+
+ for (int j = 0; j < QK6_0; ++j) {
+ const float v = xi[j];
+ const float av = fabsf(xi[j]);
+ if (amax < av) {
+ amax = av;
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -32;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->d = d;
+ memset(dsti->qh, 0, QK6_0/4);
+
+ for (int j = 0; j < QK6_0/2; ++j) {
+ const float x0 = xi[0 + j]*id;
+ const float x1 = xi[QK4_0/2 + j]*id;
+
+ const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
+ const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));
+
+ dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
+ dsti->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
+ }
+}
+
static __device__ const int8_t iq4nl_index[241] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
@@ -397,6 +432,17 @@ static void ggml_cpy_f32_q5_1_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
+static void ggml_cpy_f32_q6_0_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK6_0 == 0);
+ const int num_blocks = ne / QK6_0;
+ cpy_f32_q<cpy_blck_f32_q6_0, QK6_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -466,6 +512,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
+ ggml_cpy_f32_q6_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
@@ -505,6 +553,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q6_0, QK6_0>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 5f932fef..15e8fb5a 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -9,6 +9,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
+ type == GGML_TYPE_Q6_0 ? vec_dot_q6_0_q8_1 :
type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
@@ -34,6 +35,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
+ type == GGML_TYPE_Q6_0 ? VDR_Q6_0_Q8_1_MMVQ :
type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
@@ -232,6 +234,13 @@ static void mul_mat_vec_q5_1_q8_1_cuda(
mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
+static void mul_mat_vec_q6_0_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q6_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
static void mul_mat_vec_q8_0_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -384,6 +393,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_Q5_1:
mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
+ case GGML_TYPE_Q6_0:
+ mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index b1b465a3..7baabb7a 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -48,6 +48,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
}
+#define VDR_Q6_0_Q8_1_MMVQ 2
+#define VDR_Q6_0_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q6_0_q8_1_impl(
+ const int * vl, const int * vh, const int * u, const float & d6, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ const int vi0 = ((vl[i] >> 0) & 0x0F0F0F0F) | ((vh[i] << 4) & 0x30303030);
+ const int vi1 = ((vl[i] >> 4) & 0x0F0F0F0F) | ((vh[i] << 2) & 0x30303030);
+
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
+ }
+
+ const float2 ds8f = __half22float2(ds8);
+
+ // second part effectively subtracts 8 from each quant value
+ return d6 * (sumi * ds8f.x - (32.f*vdr/QI6_0) * ds8f.y);
+}
+
#define VDR_Q4_1_Q8_1_MMVQ 2
#define VDR_Q4_1_Q8_1_MMQ 4
@@ -549,6 +573,26 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
}
+static __device__ __forceinline__ float vec_dot_q6_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q6_0 * bq6_0 = (const block_q6_0 *) vbq + kbx;
+
+ int vl[VDR_Q6_0_Q8_1_MMVQ];
+ int vh[VDR_Q6_0_Q8_1_MMVQ];
+ int u[2*VDR_Q6_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q6_0_Q8_1_MMVQ; ++i) {
+ vl[i] = get_int_b2(bq6_0->qs, iqs + i);
+ vh[i] = get_int_b2(bq6_0->qh, i) >> 4*(iqs/2);
+ u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI6_0);
+ }
+
+ return vec_dot_q6_0_q8_1_impl<VDR_Q6_0_Q8_1_MMVQ>(vl, vh, u, bq6_0->d, bq8_1->ds);
+}
+
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {