summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--ggml/src/ggml-cuda/common.cuh7
-rw-r--r--ggml/src/ggml-cuda/convert.cu128
-rw-r--r--ggml/src/ggml-cuda/dmmv.cu264
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu38
-rw-r--r--ggml/src/ggml-cuda/mmvq.cuh1
5 files changed, 437 insertions, 1 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index a04a1929..896ba0df 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -565,6 +565,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KS> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 5afe8c74..17604f1c 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -333,6 +333,101 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
+inline __device__ int nearest_int(float fval) {
+ assert(fval <= 4194303.f);
+ float val = fval + 12582912.f;
+ int i; memcpy(&i, &val, sizeof(int));
+ return (i & 0x007fffff) - 0x00400000;
+}
+
+float __device__ __forceinline__ trellis_next(uint32_t& val) {
+ constexpr uint32_t ka = 89226354;
+ constexpr uint32_t kb = 64248484;
+ constexpr uint32_t kmask = 0x8fff8fff;
+ constexpr uint32_t km32 = 0x3b603b60;
+ uint32_t s;
+ const half * h = (const half *)&s;
+ val = ka*val + kb;
+ s = (val & kmask) ^ km32;
+ return (float)(h[0]+h[1]);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const char * cx = (const char *)vx + row * row_size;
+ float scale = *(const float *)cx;
+ const block_iq2_kt * x = (const block_iq2_kt *)(cx + sizeof(float));
+ const int64_t i = ii - (row*n_per_row)/QK_K;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t ib = tid; // 0...31
+ dst_t * y = yy + ii*QK_K + 8*ib;
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ uint32_t idx = ql[ib] + 4096;
+ const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f;
+ for (int j = 0; j < 8; ++j) {
+ y[j] = dl * trellis_next(idx);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const char * cx = (const char *)vx + row * row_size;
+ float scale = *(const float *)cx;
+ const block_iq3_kt * x = (const block_iq3_kt *)(cx + sizeof(float));
+ const int64_t i = ii - (row*n_per_row)/QK_K;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t ib = tid; // 0...31
+ dst_t * y = yy + ii*QK_K + 8*ib;
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ uint32_t idx = ql[ib] + 4096;
+ const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
+ uint8_t mask = 1 << (ib/4);
+ for (int j = 0; j < 8; ++j) {
+ y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const float * dptr = (const float *)((const char *)vx + row * row_size);
+ float scale = dptr[0] * 31.75f * 1.01f;
+ float row_av = dptr[1];
+ const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+ const int64_t i = ii - (row*n_per_row)/QK_K;
+
+ constexpr int kNumGroups = 64;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t ib = tid; // 0...31
+ dst_t * y = yy + ii*QK_K + 8*ib;
+ const uint32_t * shb = x[i].qs;
+ const uint8_t * ql = (const uint8_t *)(shb + 8); //Q::kNblock;
+ const uint8_t * qh = ql + kNumGroups;
+ const int ib32 = ib/4;
+ const int ig = ib%4;
+ const int jj = ib32*8 + 2*ig;
+ uint32_t offset = shb[ib32] & 1 ? 4096 + 32768 : 4096;
+ uint32_t idx1 = ql[jj+0] + ((qh[(jj+0)%(kNumGroups/2)] << (8 - 4*((jj+0)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+0)) & 7) << 12) + offset;
+ uint32_t idx2 = ql[jj+1] + ((qh[(jj+1)%(kNumGroups/2)] << (8 - 4*((jj+1)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+3)) & 7) << 12) + offset;
+ int ls = ((shb[ib32] & 0xff) >> 1) - 64;
+ const float dl = scale * ls;
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = dl * trellis_next(idx1) + row_av;
+ y[j+4] = dl * trellis_next(idx2) + row_av;
+ }
+}
+
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
@@ -969,6 +1064,27 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
}
template<typename dst_t>
+static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int nb = k / QK_K;
+ dequantize_block_iq2_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row));
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int nb = k / QK_K;
+ dequantize_block_iq3_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row));
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int nb = k / QK_K;
+ dequantize_block_iq4_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row));
+}
+
+template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
@@ -1230,6 +1346,12 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
+ case GGML_TYPE_IQ2_KT:
+ return dequantize_row_iq2_kt_cuda;
+ case GGML_TYPE_IQ3_KT:
+ return dequantize_row_iq3_kt_cuda;
+ case GGML_TYPE_IQ4_KT:
+ return dequantize_row_iq4_kt_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
@@ -1303,6 +1425,12 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
+ case GGML_TYPE_IQ2_KT:
+ return dequantize_row_iq2_kt_cuda;
+ case GGML_TYPE_IQ3_KT:
+ return dequantize_row_iq3_kt_cuda;
+ case GGML_TYPE_IQ4_KT:
+ return dequantize_row_iq4_kt_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
index 12738240..50e6458d 100644
--- a/ggml/src/ggml-cuda/dmmv.cu
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -1,3 +1,10 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
#include "dmmv.cuh"
#include "dequantize.cuh"
#include "convert.cuh"
@@ -8,6 +15,220 @@
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
+static __device__ __forceinline__ uint32_t trellis_next(uint32_t& val) {
+ constexpr uint32_t ka = 89226354;
+ constexpr uint32_t kb = 64248484;
+ constexpr uint32_t kmask = 0x8fff8fff;
+ constexpr uint32_t km32 = 0x3b603b60;
+ val = ka*val + kb;
+ return (val & kmask) ^ km32;
+}
+
+static __device__ __forceinline__ void trellis_accum(uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) {
+ const half * h = (const half *)s;
+ s[0] = trellis_next(val1);
+ s[1] = trellis_next(val1);
+ s[2] = trellis_next(val2);
+ s[3] = trellis_next(val2);
+#ifdef GGML_CUDA_F16
+ bdot1 = __hfma2(y[ 0], {h[0]+h[1], h[2]+h[3]}, bdot1);
+ bdot2 = __hfma2(y[64], {h[4]+h[5], h[6]+h[7]}, bdot2);
+#else
+ bdot1.x += y[ 0].x * (float)(h[0] + h[1]);
+ bdot1.y += y[ 0].y * (float)(h[2] + h[3]);
+ bdot2.x += y[64].x * (float)(h[4] + h[5]);
+ bdot2.y += y[64].y * (float)(h[6] + h[7]);
+#endif
+}
+
+static __device__ __forceinline__ void trellis_accum_abs(uint8_t signs1, uint8_t signs2, uint8_t mask1, uint8_t mask2,
+ uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) {
+ const half * h = (const half *)s;
+ s[0] = trellis_next(val1);
+ s[1] = trellis_next(val1);
+ s[2] = trellis_next(val2);
+ s[3] = trellis_next(val2);
+#ifdef GGML_CUDA_F16
+ half h00 = __habs(h[0]+h[1]), h01 = __habs(h[2]+h[3]);
+ half h10 = __habs(h[4]+h[5]), h11 = __habs(h[6]+h[7]);
+ half2 h1 = {signs1 & mask1 ? -h00 : h00, signs2 & mask1 ? -h01 : h01};
+ half2 h2 = {signs1 & mask2 ? -h10 : h10, signs2 & mask2 ? -h11 : h11};
+ bdot1 = __hfma2(y[ 0], h1, bdot1);
+ bdot2 = __hfma2(y[64], h2, bdot2);
+#else
+ bdot1.x += y[ 0].x * fabsf((float)(h[0] + h[1])) * (signs1 & mask1 ? -1 : 1);
+ bdot1.y += y[ 0].y * fabsf((float)(h[2] + h[3])) * (signs2 & mask1 ? -1 : 1);
+ bdot2.x += y[64].x * fabsf((float)(h[4] + h[5])) * (signs1 & mask2 ? -1 : 1);
+ bdot2.y += y[64].y * fabsf((float)(h[6] + h[7])) * (signs2 & mask2 ? -1 : 1);
+#endif
+}
+
+static __device__ __forceinline__ void trellis_accum(const dfloat2& dl1, const dfloat2& dl2, const dfloat2& bdot1, const dfloat2& bdot2, dfloat2& tmp) {
+#ifdef GGML_CUDA_F16
+ tmp = __hfma2(dl1, bdot1, tmp);
+ tmp = __hfma2(dl2, bdot2, tmp);
+#else
+ tmp.x += dl1.x * bdot1.x + dl2.x * bdot2.x;
+ tmp.y += dl1.y * bdot1.y + dl2.y * bdot2.y;
+#endif
+}
+
+static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst,
+ const int ncols, int nrows, int64_t row_size) {
+
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ if (row > nrows) return;
+
+ const float * dptr = (const float *)((const char *)vx + row*row_size);
+ const float d = *dptr * 31.75f * 1.05f;
+ const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
+
+ const int num_blocks_per_row = ncols / QK_K;
+
+ dfloat2 tmp = {};
+
+ const int it = threadIdx.x/2;
+ const int ix = threadIdx.x%2;
+
+ uint32_t s[4];
+
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
+ const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ const dfloat scale1 = iq4k_values[x[i].scales[it/4] & 0xf];
+ const dfloat scale2 = iq4k_values[x[i].scales[it/4] >> 4];
+ const dfloat2 dl1 = {scale1, scale1};
+ const dfloat2 dl2 = {scale2, scale2};
+ dfloat2 bdot1 = {0, 0};
+ dfloat2 bdot2 = {0, 0};
+ uint32_t val1 = ql[it+ 0] + 4096;
+ uint32_t val2 = ql[it+16] + 4096;
+ for (int k = 0; k < 4; ++k) {
+ trellis_accum(val1, val2, s, y+k, bdot1, bdot2);
+ }
+ trellis_accum(dl1, dl2, bdot1, bdot2, tmp);
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[row] = d * (float)(tmp.x + tmp.y);
+ }
+}
+
+static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst,
+ const int ncols, int nrows, int64_t row_size) {
+
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ if (row > nrows) return;
+
+ const float * dptr = (const float *)((const char *)vx + row*row_size);
+ const float d = *dptr * 31.75f * 1.015f;
+ const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+ const int num_blocks_per_row = ncols / QK_K;
+
+ dfloat2 tmp = {};
+
+ const int it = threadIdx.x/2;
+ const int ix = threadIdx.x%2;
+
+ uint32_t s[4];
+
+ uint8_t mask1 = 1 << (it/4);
+ uint8_t mask2 = mask1 << 4;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
+ const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ const uint8_t * qh = x[i].qh;
+ const dfloat scale1 = (x[i].scales[it/4] & 0xf);
+ const dfloat scale2 = (x[i].scales[it/4] >> 4);
+ const dfloat2 dl1 = {scale1, scale1};
+ const dfloat2 dl2 = {scale2, scale2};
+ dfloat2 bdot1 = {0, 0};
+ dfloat2 bdot2 = {0, 0};
+ uint32_t val1 = ql[it+ 0] + 4096;
+ uint32_t val2 = ql[it+16] + 4096;
+ for (int k = 0; k < 4; ++k) {
+ trellis_accum_abs(qh[(8*it+2*k+0)%32], qh[(8*it+2*k+1)%32], mask1, mask2, val1, val2, s, y+k, bdot1, bdot2);
+ }
+ trellis_accum(dl1, dl2, bdot1, bdot2, tmp);
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[row] = d * (float)(tmp.x + tmp.y);
+ }
+}
+
+static __global__ void dequantize_mul_mat_vec_iq4_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst,
+ const int ncols, int nrows, int64_t row_size) {
+
+ constexpr int kNumGroups = 64;
+
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ if (row > nrows) return;
+
+ const float * dptr = (const float *)((const char *)vx + row*row_size);
+ const float d = dptr[0] * 31.75f * 1.01f;
+ const float row_av = dptr[1];
+ const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+
+ const int num_blocks_per_row = ncols / QK_K;
+
+ dfloat2 tmp1 = {};
+ dfloat2 tmp2 = {};
+
+ const int it = threadIdx.x/2; // 0...15
+ const int ix = threadIdx.x%2; // 0 or 1
+
+ uint32_t s[4];
+
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
+ const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
+ const uint32_t * shb = x[i].qs;
+ const uint8_t * ql = (const uint8_t *)(shb + 8);
+ const uint8_t * qh = ql + kNumGroups;
+ const uint32_t offset1 = 4096 + ((shb[it/4+0] & 1) << 15);
+ const uint32_t offset2 = 4096 + ((shb[it/4+4] & 1) << 15);
+ const dfloat scale1 = (int)((shb[it/4+0] & 0xff) >> 1) - 64;
+ const dfloat scale2 = (int)((shb[it/4+4] & 0xff) >> 1) - 64;
+ const dfloat2 dl1 = {scale1, scale1};
+ const dfloat2 dl2 = {scale2, scale2};
+ const uint32_t sh1 = shb[it/4+0] >> (8 + 6*(it%4));
+ const uint32_t sh2 = shb[it/4+4] >> (8 + 6*(it%4));
+ dfloat2 bdot1 = {0, 0};
+ dfloat2 bdot2 = {0, 0};
+ uint32_t val1 = ql[2*it+ 0] + ((qh[2*it+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
+ uint32_t val2 = ql[2*it+32] + ((qh[2*it+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
+ uint32_t val3 = ql[2*it+ 1] + ((qh[2*it+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
+ uint32_t val4 = ql[2*it+33] + ((qh[2*it+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
+ for (int k = 0; k < 2; ++k) {
+ trellis_accum(val1, val2, s, y+k+0, bdot1, bdot2);
+ trellis_accum(val3, val4, s, y+k+2, bdot1, bdot2);
+#ifdef GGML_CUDA_F16
+ tmp2 += y[k] + y[k+2] + y[k+64] + y[k+66];
+#else
+ tmp2.x += y[k].x + y[k+2].x + y[k+64].x + y[k+66].x;
+ tmp2.y += y[k].y + y[k+2].y + y[k+64].y + y[k+66].y;
+#endif
+ }
+ trellis_accum(dl1, dl2, bdot1, bdot2, tmp1);
+ }
+
+ // sum up partial sums and write back result
+ float tmp = d * (float)(tmp1.x + tmp1.y) + row_av * (float)(tmp2.x + tmp2.y);
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[row] = tmp;
+ }
+}
+
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -554,6 +775,36 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
+static void dequantize_mul_mat_vec_iq2_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ constexpr int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(32, ny, 1);
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_KT, ncols);
+ dequantize_mul_mat_vec_iq2_kt<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, row_size);
+}
+
+static void dequantize_mul_mat_vec_iq3_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ constexpr int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(32, ny, 1);
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ3_KT, ncols);
+ dequantize_mul_mat_vec_iq3_kt<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, row_size);
+}
+
+static void dequantize_mul_mat_vec_iq4_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ constexpr int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(32, ny, 1);
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KT, ncols);
+ dequantize_mul_mat_vec_iq4_kt<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, row_size);
+}
+
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
@@ -615,7 +866,8 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
bool src1_convert_f16 =
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
- src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+ src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 ||
+ src0->type == GGML_TYPE_IQ2_KT || src0->type == GGML_TYPE_IQ3_KT || src0->type == GGML_TYPE_IQ4_KT;
if (src1_convert_f16) {
src1_dfloat = src1_dfloat_a.alloc(ne00);
@@ -646,6 +898,15 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
case GGML_TYPE_Q2_K:
dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
+ case GGML_TYPE_IQ2_KT:
+ dequantize_mul_mat_vec_iq2_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ3_KT:
+ dequantize_mul_mat_vec_iq3_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ4_KT:
+ dequantize_mul_mat_vec_iq4_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
case GGML_TYPE_Q3_K:
dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
@@ -679,5 +940,6 @@ bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
+ src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT ||
src0_type == GGML_TYPE_F16;
}
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 14fe2547..d0477835 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -618,3 +618,41 @@ void ggml_cuda_op_mul_mat_vec_q_id(
GGML_UNUSED(src1_ddf_i);
}
+
+bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
+ switch (src0_type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ2_K:
+ case GGML_TYPE_IQ3_K:
+ case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KSS:
+ case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_IQ6_K:
+ case GGML_TYPE_IQ3_S:
+ return true;
+ default:
+ return false;
+ }
+}
diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh
index 525c6bc0..d17765f1 100644
--- a/ggml/src/ggml-cuda/mmvq.cuh
+++ b/ggml/src/ggml-cuda/mmvq.cuh
@@ -14,6 +14,7 @@ void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
+bool ggml_cuda_mmvq_type_supported(ggml_type src0_type);
void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,