Diffstat (limited to 'ggml/src')
-rw-r--r--  ggml/src/CMakeLists.txt                2
-rw-r--r--  ggml/src/ggml-common.h                18
-rw-r--r--  ggml/src/ggml-cuda.cu                  4
-rw-r--r--  ggml/src/ggml-cuda/common.cuh          7
-rw-r--r--  ggml/src/ggml-cuda/convert.cu        128
-rw-r--r--  ggml/src/ggml-cuda/dmmv.cu           264
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu            38
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cuh            1
-rw-r--r--  ggml/src/ggml-quants.c                 3
-rw-r--r--  ggml/src/ggml.c                       66
-rw-r--r--  ggml/src/iqk/iqk_gemm_ktquants.cpp   403
-rw-r--r--  ggml/src/iqk/iqk_gemm_ktquants.h      11
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp           7
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp       1386
-rw-r--r--  ggml/src/iqk/iqk_quantize.h           18
15 files changed, 2354 insertions, 2 deletions
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 9872b3de..b0db417d 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -268,6 +268,7 @@ if (GGML_IQK_MUL_MAT)
iqk/fa/iqk_fa_64_64.cpp
iqk/iqk_gemm_floats.cpp
iqk/iqk_gemm_kquants.cpp
+ iqk/iqk_gemm_ktquants.cpp
iqk/iqk_gemm_iquants.cpp
iqk/iqk_gemm_iqk_quants.cpp
iqk/iqk_gemm_1bit.cpp
@@ -277,6 +278,7 @@ if (GGML_IQK_MUL_MAT)
iqk/fa/iqk_fa_templates.h
iqk/iqk_gemm_floats.h
iqk/iqk_gemm_kquants.h
+ iqk/iqk_gemm_ktquants.h
iqk/iqk_gemm_iquants.h
iqk/iqk_gemm_iqk_quants.h
iqk/iqk_gemm_1bit.h
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 26041ac2..5fe27b29 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -621,6 +621,24 @@ typedef struct {
static_assert(sizeof(block_iq2_ks) == sizeof(uint16_t) + QK_K/64 + QK_K/4, "wrong iq2_ks block size/padding");
typedef struct {
+ uint8_t scales[QK_K/64];
+ uint8_t ql[QK_K/4];
+} block_iq2_kt;
+static_assert(sizeof(block_iq2_kt) == QK_K/4 + QK_K/64, "wrong iq2_kt block size/padding");
+
+typedef struct {
+ uint8_t scales[QK_K/64];
+ uint8_t ql[QK_K/4];
+ uint8_t qh[QK_K/8];
+} block_iq3_kt;
+static_assert(sizeof(block_iq3_kt) == QK_K/4 + QK_K/8 + QK_K/64, "wrong iq3_kt block size/padding");
+
+typedef struct {
+ uint32_t qs[QK_K/8];
+} block_iq4_kt;
+static_assert(sizeof(block_iq4_kt) == QK_K/2, "wrong iq4_kt block size/padding");
+
+typedef struct {
ggml_half d;
uint16_t extra;
uint16_t scales_h;
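The three structs added above fix the storage cost of the new trellis ("kt") types: with QK_K = 256 a superblock occupies 68, 100 and 128 bytes for iq2_kt, iq3_kt and iq4_kt, i.e. 2.125, 3.125 and 4.0 bits per weight, plus the per-row float metadata registered later in ggml.c (row_meta_size of 4, 4 and 8 bytes). The standalone sketch below is illustrative only; it re-declares the structs locally instead of including ggml-common.h and just recomputes those numbers.

#include <cstdint>
#include <cstdio>

constexpr int QK_K = 256;

// Local re-declarations mirroring the structs added to ggml-common.h.
struct block_iq2_kt { uint8_t scales[QK_K/64]; uint8_t ql[QK_K/4]; };
struct block_iq3_kt { uint8_t scales[QK_K/64]; uint8_t ql[QK_K/4]; uint8_t qh[QK_K/8]; };
struct block_iq4_kt { uint32_t qs[QK_K/8]; };

int main() {
    // Expected: 68 B -> 2.125 bpw, 100 B -> 3.125 bpw, 128 B -> 4.000 bpw
    printf("iq2_kt: %3zu bytes/superblock, %.3f bpw\n", sizeof(block_iq2_kt), 8.0*sizeof(block_iq2_kt)/QK_K);
    printf("iq3_kt: %3zu bytes/superblock, %.3f bpw\n", sizeof(block_iq3_kt), 8.0*sizeof(block_iq3_kt)/QK_K);
    printf("iq4_kt: %3zu bytes/superblock, %.3f bpw\n", sizeof(block_iq4_kt), 8.0*sizeof(block_iq4_kt)/QK_K);
    return 0;
}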
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 9c8c91f4..f55715f1 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2111,6 +2111,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
+ && ggml_cuda_mmvq_type_supported(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
@@ -3460,6 +3461,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index a04a1929..896ba0df 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -565,6 +565,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KS> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 5afe8c74..17604f1c 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -333,6 +333,101 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
+inline __device__ int nearest_int(float fval) {
+ assert(fval <= 4194303.f);
+ float val = fval + 12582912.f;
+ int i; memcpy(&i, &val, sizeof(int));
+ return (i & 0x007fffff) - 0x00400000;
+}
+
+float __device__ __forceinline__ trellis_next(uint32_t& val) {
+ constexpr uint32_t ka = 89226354;
+ constexpr uint32_t kb = 64248484;
+ constexpr uint32_t kmask = 0x8fff8fff;
+ constexpr uint32_t km32 = 0x3b603b60;
+ uint32_t s;
+ const half * h = (const half *)&s;
+ val = ka*val + kb;
+ s = (val & kmask) ^ km32;
+ return (float)(h[0]+h[1]);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const char * cx = (const char *)vx + row * row_size;
+ float scale = *(const float *)cx;
+ const block_iq2_kt * x = (const block_iq2_kt *)(cx + sizeof(float));
+ const int64_t i = ii - (row*n_per_row)/QK_K;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t ib = tid; // 0...31
+ dst_t * y = yy + ii*QK_K + 8*ib;
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ uint32_t idx = ql[ib] + 4096;
+ const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f;
+ for (int j = 0; j < 8; ++j) {
+ y[j] = dl * trellis_next(idx);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const char * cx = (const char *)vx + row * row_size;
+ float scale = *(const float *)cx;
+ const block_iq3_kt * x = (const block_iq3_kt *)(cx + sizeof(float));
+ const int64_t i = ii - (row*n_per_row)/QK_K;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t ib = tid; // 0...31
+ dst_t * y = yy + ii*QK_K + 8*ib;
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ uint32_t idx = ql[ib] + 4096;
+ const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
+ uint8_t mask = 1 << (ib/4);
+ for (int j = 0; j < 8; ++j) {
+ y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const float * dptr = (const float *)((const char *)vx + row * row_size);
+ float scale = dptr[0] * 31.75f * 1.01f;
+ float row_av = dptr[1];
+ const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+ const int64_t i = ii - (row*n_per_row)/QK_K;
+
+ constexpr int kNumGroups = 64;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t ib = tid; // 0...31
+ dst_t * y = yy + ii*QK_K + 8*ib;
+ const uint32_t * shb = x[i].qs;
+ const uint8_t * ql = (const uint8_t *)(shb + 8); //Q::kNblock;
+ const uint8_t * qh = ql + kNumGroups;
+ const int ib32 = ib/4;
+ const int ig = ib%4;
+ const int jj = ib32*8 + 2*ig;
+ uint32_t offset = shb[ib32] & 1 ? 4096 + 32768 : 4096;
+ uint32_t idx1 = ql[jj+0] + ((qh[(jj+0)%(kNumGroups/2)] << (8 - 4*((jj+0)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+0)) & 7) << 12) + offset;
+ uint32_t idx2 = ql[jj+1] + ((qh[(jj+1)%(kNumGroups/2)] << (8 - 4*((jj+1)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+3)) & 7) << 12) + offset;
+ int ls = ((shb[ib32] & 0xff) >> 1) - 64;
+ const float dl = scale * ls;
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = dl * trellis_next(idx1) + row_av;
+ y[j+4] = dl * trellis_next(idx2) + row_av;
+ }
+}
+
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
@@ -969,6 +1064,27 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
}
template<typename dst_t>
+static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int nb = k / QK_K;
+ dequantize_block_iq2_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row));
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int nb = k / QK_K;
+ dequantize_block_iq3_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row));
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int nb = k / QK_K;
+ dequantize_block_iq4_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row));
+}
+
+template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
@@ -1230,6 +1346,12 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
+ case GGML_TYPE_IQ2_KT:
+ return dequantize_row_iq2_kt_cuda;
+ case GGML_TYPE_IQ3_KT:
+ return dequantize_row_iq3_kt_cuda;
+ case GGML_TYPE_IQ4_KT:
+ return dequantize_row_iq4_kt_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
@@ -1303,6 +1425,12 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
+ case GGML_TYPE_IQ2_KT:
+ return dequantize_row_iq2_kt_cuda;
+ case GGML_TYPE_IQ3_KT:
+ return dequantize_row_iq3_kt_cuda;
+ case GGML_TYPE_IQ4_KT:
+ return dequantize_row_iq4_kt_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
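All three dequantize kernels above share one generator: a group's packed index (plus an offset of 4096) seeds the recurrence val = ka*val + kb, the state is masked with 0x8fff8fff and XORed with 0x3b603b60, and the two fp16 halves of the result are summed to produce one weight. Below is a scalar host-side sketch of that generator, for illustration only; half_to_float is a hypothetical stand-in for GGML_FP16_TO_FP32.

#include <cmath>
#include <cstdint>
#include <cstdio>

// IEEE-754 binary16 -> float, good enough for this illustration.
static float half_to_float(uint16_t h) {
    const uint32_t sign = (h >> 15) & 1;
    const uint32_t exp  = (h >> 10) & 0x1f;
    const uint32_t mant =  h        & 0x3ff;
    float v;
    if      (exp ==  0) v = std::ldexp((float)mant, -24);                // zero / subnormal
    else if (exp == 31) v = mant ? NAN : INFINITY;                       // inf / NaN
    else                v = std::ldexp((float)(mant | 0x400), (int)exp - 25);
    return sign ? -v : v;
}

// Mirrors trellis_next() in ggml-cuda/convert.cu: one LCG step, then the
// masked/XORed state is read as two fp16 values whose sum is the next weight.
static float trellis_next(uint32_t & val) {
    constexpr uint32_t ka = 89226354, kb = 64248484;
    constexpr uint32_t kmask = 0x8fff8fff, km32 = 0x3b603b60;
    val = ka*val + kb;
    const uint32_t s = (val & kmask) ^ km32;
    return half_to_float((uint16_t)(s & 0xffff)) + half_to_float((uint16_t)(s >> 16));
}

int main() {
    // Decode one iq2_kt group: ql[ib] (here the arbitrary value 1234) plus the
    // 4096 offset seeds the trellis; 8 steps give the 8 unscaled weights.
    uint32_t idx = 1234 + 4096;
    for (int j = 0; j < 8; ++j) printf("%8.4f\n", trellis_next(idx));
    return 0;
}

The CUDA kernels here and the AVX2 path in iqk_gemm_ktquants.cpp further below implement exactly this recurrence, only vectorized.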
diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
index 12738240..50e6458d 100644
--- a/ggml/src/ggml-cuda/dmmv.cu
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -1,3 +1,10 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
#include "dmmv.cuh"
#include "dequantize.cuh"
#include "convert.cuh"
@@ -8,6 +15,220 @@
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
+static __device__ __forceinline__ uint32_t trellis_next(uint32_t& val) {
+ constexpr uint32_t ka = 89226354;
+ constexpr uint32_t kb = 64248484;
+ constexpr uint32_t kmask = 0x8fff8fff;
+ constexpr uint32_t km32 = 0x3b603b60;
+ val = ka*val + kb;
+ return (val & kmask) ^ km32;
+}
+
+static __device__ __forceinline__ void trellis_accum(uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) {
+ const half * h = (const half *)s;
+ s[0] = trellis_next(val1);
+ s[1] = trellis_next(val1);
+ s[2] = trellis_next(val2);
+ s[3] = trellis_next(val2);
+#ifdef GGML_CUDA_F16
+ bdot1 = __hfma2(y[ 0], {h[0]+h[1], h[2]+h[3]}, bdot1);
+ bdot2 = __hfma2(y[64], {h[4]+h[5], h[6]+h[7]}, bdot2);
+#else
+ bdot1.x += y[ 0].x * (float)(h[0] + h[1]);
+ bdot1.y += y[ 0].y * (float)(h[2] + h[3]);
+ bdot2.x += y[64].x * (float)(h[4] + h[5]);
+ bdot2.y += y[64].y * (float)(h[6] + h[7]);
+#endif
+}
+
+static __device__ __forceinline__ void trellis_accum_abs(uint8_t signs1, uint8_t signs2, uint8_t mask1, uint8_t mask2,
+ uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) {
+ const half * h = (const half *)s;
+ s[0] = trellis_next(val1);
+ s[1] = trellis_next(val1);
+ s[2] = trellis_next(val2);
+ s[3] = trellis_next(val2);
+#ifdef GGML_CUDA_F16
+ half h00 = __habs(h[0]+h[1]), h01 = __habs(h[2]+h[3]);
+ half h10 = __habs(h[4]+h[5]), h11 = __habs(h[6]+h[7]);
+ half2 h1 = {signs1 & mask1 ? -h00 : h00, signs2 & mask1 ? -h01 : h01};
+ half2 h2 = {signs1 & mask2 ? -h10 : h10, signs2 & mask2 ? -h11 : h11};
+ bdot1 = __hfma2(y[ 0], h1, bdot1);
+ bdot2 = __hfma2(y[64], h2, bdot2);
+#else
+ bdot1.x += y[ 0].x * fabsf((float)(h[0] + h[1])) * (signs1 & mask1 ? -1 : 1);
+ bdot1.y += y[ 0].y * fabsf((float)(h[2] + h[3])) * (signs2 & mask1 ? -1 : 1);
+ bdot2.x += y[64].x * fabsf((float)(h[4] + h[5])) * (signs1 & mask2 ? -1 : 1);
+ bdot2.y += y[64].y * fabsf((float)(h[6] + h[7])) * (signs2 & mask2 ? -1 : 1);
+#endif
+}
+
+static __device__ __forceinline__ void trellis_accum(const dfloat2& dl1, const dfloat2& dl2, const dfloat2& bdot1, const dfloat2& bdot2, dfloat2& tmp) {
+#ifdef GGML_CUDA_F16
+ tmp = __hfma2(dl1, bdot1, tmp);
+ tmp = __hfma2(dl2, bdot2, tmp);
+#else
+ tmp.x += dl1.x * bdot1.x + dl2.x * bdot2.x;
+ tmp.y += dl1.y * bdot1.y + dl2.y * bdot2.y;
+#endif
+}
+
+static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst,
+ const int ncols, int nrows, int64_t row_size) {
+
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ if (row > nrows) return;
+
+ const float * dptr = (const float *)((const char *)vx + row*row_size);
+ const float d = *dptr * 31.75f * 1.05f;
+ const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
+
+ const int num_blocks_per_row = ncols / QK_K;
+
+ dfloat2 tmp = {};
+
+ const int it = threadIdx.x/2;
+ const int ix = threadIdx.x%2;
+
+ uint32_t s[4];
+
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
+ const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ const dfloat scale1 = iq4k_values[x[i].scales[it/4] & 0xf];
+ const dfloat scale2 = iq4k_values[x[i].scales[it/4] >> 4];
+ const dfloat2 dl1 = {scale1, scale1};
+ const dfloat2 dl2 = {scale2, scale2};
+ dfloat2 bdot1 = {0, 0};
+ dfloat2 bdot2 = {0, 0};
+ uint32_t val1 = ql[it+ 0] + 4096;
+ uint32_t val2 = ql[it+16] + 4096;
+ for (int k = 0; k < 4; ++k) {
+ trellis_accum(val1, val2, s, y+k, bdot1, bdot2);
+ }
+ trellis_accum(dl1, dl2, bdot1, bdot2, tmp);
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[row] = d * (float)(tmp.x + tmp.y);
+ }
+}
+
+static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst,
+ const int ncols, int nrows, int64_t row_size) {
+
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ if (row > nrows) return;
+
+ const float * dptr = (const float *)((const char *)vx + row*row_size);
+ const float d = *dptr * 31.75f * 1.015f;
+ const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+ const int num_blocks_per_row = ncols / QK_K;
+
+ dfloat2 tmp = {};
+
+ const int it = threadIdx.x/2;
+ const int ix = threadIdx.x%2;
+
+ uint32_t s[4];
+
+ uint8_t mask1 = 1 << (it/4);
+ uint8_t mask2 = mask1 << 4;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
+ const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ const uint8_t * qh = x[i].qh;
+ const dfloat scale1 = (x[i].scales[it/4] & 0xf);
+ const dfloat scale2 = (x[i].scales[it/4] >> 4);
+ const dfloat2 dl1 = {scale1, scale1};
+ const dfloat2 dl2 = {scale2, scale2};
+ dfloat2 bdot1 = {0, 0};
+ dfloat2 bdot2 = {0, 0};
+ uint32_t val1 = ql[it+ 0] + 4096;
+ uint32_t val2 = ql[it+16] + 4096;
+ for (int k = 0; k < 4; ++k) {
+ trellis_accum_abs(qh[(8*it+2*k+0)%32], qh[(8*it+2*k+1)%32], mask1, mask2, val1, val2, s, y+k, bdot1, bdot2);
+ }
+ trellis_accum(dl1, dl2, bdot1, bdot2, tmp);
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[row] = d * (float)(tmp.x + tmp.y);
+ }
+}
+
+static __global__ void dequantize_mul_mat_vec_iq4_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst,
+ const int ncols, int nrows, int64_t row_size) {
+
+ constexpr int kNumGroups = 64;
+
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ if (row > nrows) return;
+
+ const float * dptr = (const float *)((const char *)vx + row*row_size);
+ const float d = dptr[0] * 31.75f * 1.01f;
+ const float row_av = dptr[1];
+ const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+
+ const int num_blocks_per_row = ncols / QK_K;
+
+ dfloat2 tmp1 = {};
+ dfloat2 tmp2 = {};
+
+ const int it = threadIdx.x/2; // 0...15
+ const int ix = threadIdx.x%2; // 0 or 1
+
+ uint32_t s[4];
+
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
+ const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
+ const uint32_t * shb = x[i].qs;
+ const uint8_t * ql = (const uint8_t *)(shb + 8);
+ const uint8_t * qh = ql + kNumGroups;
+ const uint32_t offset1 = 4096 + ((shb[it/4+0] & 1) << 15);
+ const uint32_t offset2 = 4096 + ((shb[it/4+4] & 1) << 15);
+ const dfloat scale1 = (int)((shb[it/4+0] & 0xff) >> 1) - 64;
+ const dfloat scale2 = (int)((shb[it/4+4] & 0xff) >> 1) - 64;
+ const dfloat2 dl1 = {scale1, scale1};
+ const dfloat2 dl2 = {scale2, scale2};
+ const uint32_t sh1 = shb[it/4+0] >> (8 + 6*(it%4));
+ const uint32_t sh2 = shb[it/4+4] >> (8 + 6*(it%4));
+ dfloat2 bdot1 = {0, 0};
+ dfloat2 bdot2 = {0, 0};
+ uint32_t val1 = ql[2*it+ 0] + ((qh[2*it+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
+ uint32_t val2 = ql[2*it+32] + ((qh[2*it+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
+ uint32_t val3 = ql[2*it+ 1] + ((qh[2*it+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
+ uint32_t val4 = ql[2*it+33] + ((qh[2*it+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
+ for (int k = 0; k < 2; ++k) {
+ trellis_accum(val1, val2, s, y+k+0, bdot1, bdot2);
+ trellis_accum(val3, val4, s, y+k+2, bdot1, bdot2);
+#ifdef GGML_CUDA_F16
+ tmp2 += y[k] + y[k+2] + y[k+64] + y[k+66];
+#else
+ tmp2.x += y[k].x + y[k+2].x + y[k+64].x + y[k+66].x;
+ tmp2.y += y[k].y + y[k+2].y + y[k+64].y + y[k+66].y;
+#endif
+ }
+ trellis_accum(dl1, dl2, bdot1, bdot2, tmp1);
+ }
+
+ // sum up partial sums and write back result
+ float tmp = d * (float)(tmp1.x + tmp1.y) + row_av * (float)(tmp2.x + tmp2.y);
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[row] = tmp;
+ }
+}
+
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -554,6 +775,36 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
+static void dequantize_mul_mat_vec_iq2_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ constexpr int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(32, ny, 1);
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_KT, ncols);
+ dequantize_mul_mat_vec_iq2_kt<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, row_size);
+}
+
+static void dequantize_mul_mat_vec_iq3_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ constexpr int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(32, ny, 1);
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ3_KT, ncols);
+ dequantize_mul_mat_vec_iq3_kt<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, row_size);
+}
+
+static void dequantize_mul_mat_vec_iq4_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ constexpr int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(32, ny, 1);
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KT, ncols);
+ dequantize_mul_mat_vec_iq4_kt<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, row_size);
+}
+
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
@@ -615,7 +866,8 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
bool src1_convert_f16 =
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
- src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+ src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 ||
+ src0->type == GGML_TYPE_IQ2_KT || src0->type == GGML_TYPE_IQ3_KT || src0->type == GGML_TYPE_IQ4_KT;
if (src1_convert_f16) {
src1_dfloat = src1_dfloat_a.alloc(ne00);
@@ -646,6 +898,15 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
case GGML_TYPE_Q2_K:
dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
+ case GGML_TYPE_IQ2_KT:
+ dequantize_mul_mat_vec_iq2_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ3_KT:
+ dequantize_mul_mat_vec_iq3_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ4_KT:
+ dequantize_mul_mat_vec_iq4_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
case GGML_TYPE_Q3_K:
dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
@@ -679,5 +940,6 @@ bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
+ src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT ||
src0_type == GGML_TYPE_F16;
}
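The iq4_kt kernels (here and in convert.cu above) reassemble each 4-weight group's trellis seed from three places: 8 low bits in ql, 4 bits from a nibble of qh, and 3 bits plus the 32768 offset flag and the 7-bit block scale from the per-32-weight header word in shb. The helpers below are a hypothetical scalar restatement of that packing (assuming QK_K == 256 and the layout of block_iq4_kt.qs as 8 header words, 64 ql bytes and 32 qh bytes); each returned seed then drives four trellis steps, one per weight in the group.

#include <cstdint>

// Rebuild the trellis seed of group g (0..63) inside one iq4_kt superblock.
static uint32_t iq4_kt_group_seed(const uint32_t * shb, const uint8_t * ql, const uint8_t * qh, int g) {
    const int ib32 = g / 8;   // which 32-weight block the group belongs to
    const int ig   = g % 8;   // group index inside that block
    const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);                 // header bit 0: +32768
    const uint32_t lo8    = ql[g];                                          // low 8 index bits
    const uint32_t mid4   = g < 32 ? ((uint32_t)qh[g]      << 8) & 0xf00    // low nibble of qh
                                   : ((uint32_t)qh[g - 32] << 4) & 0xf00;   // high nibble of qh
    const uint32_t hi3    = ((shb[ib32] >> (8 + 3*ig)) & 7) << 12;          // 3 bits from the header
    return lo8 + mid4 + hi3 + offset;
}

// The remaining 7 header bits hold the block scale, biased by 64.
static int iq4_kt_block_scale(const uint32_t * shb, int ib32) {
    return (int)((shb[ib32] & 0xff) >> 1) - 64;
}

Per 32 weights that is 4 header bytes + 8 ql bytes + 4 qh bytes = 16 bytes, which is where the flat 4.0 bpw of block_iq4_kt comes from.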
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 14fe2547..d0477835 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -618,3 +618,41 @@ void ggml_cuda_op_mul_mat_vec_q_id(
GGML_UNUSED(src1_ddf_i);
}
+
+bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
+ switch (src0_type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ2_K:
+ case GGML_TYPE_IQ3_K:
+ case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KSS:
+ case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_IQ6_K:
+ case GGML_TYPE_IQ3_S:
+ return true;
+ default:
+ return false;
+ }
+}
diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh
index 525c6bc0..d17765f1 100644
--- a/ggml/src/ggml-cuda/mmvq.cuh
+++ b/ggml/src/ggml-cuda/mmvq.cuh
@@ -14,6 +14,7 @@ void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
+bool ggml_cuda_mmvq_type_supported(ggml_type src0_type);
void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 0e6aa677..220c0c99 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15421,6 +15421,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_Q6_0: break;
case GGML_TYPE_IQ2_K: break;
case GGML_TYPE_IQ2_KS: break;
+ case GGML_TYPE_IQ2_KT: break;
+ case GGML_TYPE_IQ3_KT: break;
+ case GGML_TYPE_IQ4_KT: break;
case GGML_TYPE_IQ3_K: break;
case GGML_TYPE_IQ4_K: break;
case GGML_TYPE_IQ5_K: break;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 7cbc0056..d8025a5a 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1574,6 +1574,45 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 2,
},
+ [GGML_TYPE_IQ2_KT] = {
+ .type_name = "iq2_kt",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq2_kt),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_kt,
+ .from_float = quantize_row_iq2_kt,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq2_kt_ref,
+ .vec_dot = vec_dot_iq2_kt_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ .row_meta_size = 4,
+ },
+ [GGML_TYPE_IQ3_KT] = {
+ .type_name = "iq3_kt",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq3_kt),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq3_kt,
+ .from_float = quantize_row_iq3_kt,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq3_kt_ref,
+ .vec_dot = vec_dot_iq3_kt_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ .row_meta_size = 4,
+ },
+ [GGML_TYPE_IQ4_KT] = {
+ .type_name = "iq4_kt",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq4_kt),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq4_kt,
+ .from_float = quantize_row_iq4_kt,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref,
+ .vec_dot = vec_dot_iq4_kt_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ .row_meta_size = 8,
+ },
[GGML_TYPE_IQ3_K] = {
.type_name = "iq3_k",
.blck_size = QK_K,
@@ -4501,6 +4540,9 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break;
case GGML_FTYPE_MOSTLY_IQ2_K_R4: wtype = GGML_TYPE_IQ2_K_R4; break;
case GGML_FTYPE_MOSTLY_IQ2_KS: wtype = GGML_TYPE_IQ2_KS; break;
+ case GGML_FTYPE_MOSTLY_IQ2_KT: wtype = GGML_TYPE_IQ2_KT; break;
+ case GGML_FTYPE_MOSTLY_IQ3_KT: wtype = GGML_TYPE_IQ3_KT; break;
+ case GGML_FTYPE_MOSTLY_IQ4_KT: wtype = GGML_TYPE_IQ4_KT; break;
case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break;
case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break;
case GGML_FTYPE_MOSTLY_IQ3_K_R4: wtype = GGML_TYPE_IQ3_K_R4; break;
@@ -11266,6 +11308,9 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_K_R4:
@@ -11740,6 +11785,9 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_K_R4:
@@ -11911,6 +11959,9 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_K_R4:
@@ -15409,6 +15460,9 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_K_R4:
@@ -15820,6 +15874,9 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_K_R4:
@@ -16137,6 +16194,9 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_K_R4:
@@ -16771,6 +16831,9 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_K_R4:
@@ -23841,6 +23904,9 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_K_R4:result = quantize_iq2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_KS: result = quantize_iq2_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_KT: result = quantize_iq2_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ3_KT: result = quantize_iq3_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ4_KT: result = quantize_iq4_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ3_K_R4:result = quantize_iq3_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
new file mode 100644
index 00000000..c38dcdc6
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -0,0 +1,403 @@
+#include "iqk_gemm_ktquants.h"
+#include "ggml.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include "ggml-impl.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#ifdef __x86_64__
+
+namespace {
+
+static inline uint32_t trellis_next(uint32_t& val) {
+ constexpr uint32_t ka = 89226354;
+ constexpr uint32_t kb = 64248484;
+ constexpr uint32_t kmask = 0x8fff8fff;
+ constexpr uint32_t km32 = 0x3b603b60;
+ val = val*ka + kb;
+ return (val & kmask) ^ km32;
+}
+
+static inline __m256i trellis_next8(uint32_t val) {
+ constexpr uint32_t kmask = 0x8fff8fff;
+ constexpr uint32_t km32 = 0x3b603b60;
+ constexpr uint32_t ka = 89226354;
+ constexpr uint32_t kb = 64248484;
+ constexpr uint32_t ka1 = ka*ka;
+ constexpr uint32_t kb1 = kb*ka+kb;
+ constexpr uint32_t ka2 = ka1*ka;
+ constexpr uint32_t kb2 = kb1*ka+kb;
+ constexpr uint32_t ka3 = ka2*ka;
+ constexpr uint32_t kb3 = kb2*ka+kb;
+ constexpr uint32_t ka4 = ka3*ka;
+ constexpr uint32_t kb4 = kb3*ka+kb;
+ constexpr uint32_t ka5 = ka4*ka;
+ constexpr uint32_t kb5 = kb4*ka+kb;
+ constexpr uint32_t ka6 = ka5*ka;
+ constexpr uint32_t kb6 = kb5*ka+kb;
+ constexpr uint32_t ka7 = ka6*ka;
+ constexpr uint32_t kb7 = kb6*ka+kb;
+ __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7);
+ __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7);
+ __m256i mval = _mm256_set1_epi32(val);
+ __m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
+ return _mm256_and_si256(mres, _mm256_set1_epi32(kmask)) ^ _mm256_set1_epi32(km32);
+}
+
+static inline float trellis_gen(uint32_t& val, uint32_t* s) {
+ const ggml_fp16_t * h = (const ggml_fp16_t *)s;
+ s[0] = trellis_next(val);
+ return GGML_FP16_TO_FP32(h[0]) + GGML_FP16_TO_FP32(h[1]);
+}
+
+struct Trellis1 {
+ constexpr static uint32_t kmask = 0x8fff8fff;
+ constexpr static uint32_t km32 = 0x3b603b60;
+ constexpr static uint32_t ka = 89226354;
+ constexpr static uint32_t kb = 64248484;
+ constexpr static uint32_t ka1 = ka*ka;
+ constexpr static uint32_t kb1 = kb*ka+kb;
+ constexpr static uint32_t ka2 = ka1*ka;
+ constexpr static uint32_t kb2 = kb1*ka+kb;
+ constexpr static uint32_t ka3 = ka2*ka;
+ constexpr static uint32_t kb3 = kb2*ka+kb;
+ constexpr static uint32_t ka4 = ka3*ka;
+ constexpr static uint32_t kb4 = kb3*ka+kb;
+ constexpr static uint32_t ka5 = ka4*ka;
+ constexpr static uint32_t kb5 = kb4*ka+kb;
+ constexpr static uint32_t ka6 = ka5*ka;
+ constexpr static uint32_t kb6 = kb5*ka+kb;
+ constexpr static uint32_t ka7 = ka6*ka;
+ constexpr static uint32_t kb7 = kb6*ka+kb;
+ const __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7);
+ const __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7);
+ const __m256i mask1 = _mm256_set1_epi32(kmask);
+ const __m256i mask2 = _mm256_set1_epi32(km32);
+
+ inline __m256i next8(uint32_t val) const {
+ auto mval = _mm256_set1_epi32(val);
+ auto mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
+ return _mm256_and_si256(mres, mask1) ^ mask2;
+ }
+};
+
+static inline __m256 trellis_gen8(__m256i i8) {
+ // split each 32-bit lane into its lower and upper 16-bit halves (two fp16 values per lane)
+ __m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF);
+ __m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask);
+ __m256i upper_halves_lanes32 = _mm256_srli_epi32(i8, 16);
+ // 00L0, 00L1, 00L2, 00L3, 00H0, 00H1, 00H2, 00H3, 00L4, 00L5, 00L6, 00L7, 00H4, 00H5, 00H6, 00H7
+ auto iv = _mm256_packus_epi32(lower_halves_lanes32, upper_halves_lanes32);
+ // 00L0, 00L1, 00L2, 00L3, 00L4, 00L5, 00L6, 00L7, 00H0, 00H1, 00H2, 00H3, 00H4, 00H5, 00H6, 00H7
+ iv = _mm256_permute4x64_epi64(iv, 0xd8);
+ auto fv1 = _mm256_cvtph_ps(_mm256_extracti128_si256(iv, 0));
+ auto fv2 = _mm256_cvtph_ps(_mm256_extracti128_si256(iv, 1));
+ return _mm256_add_ps(fv1, fv2);
+}
+
+struct Trellis2 {
+ constexpr static uint32_t kmask = 0x8fff8fff;
+ constexpr static uint32_t km32 = 0x3b603b60;
+ constexpr static uint32_t ka = 89226354;
+ constexpr static uint32_t kb = 64248484;
+ constexpr static uint32_t ka1 = ka*ka;
+ constexpr static uint32_t kb1 = kb*ka+kb;
+ constexpr static uint32_t ka2 = ka1*ka;
+ constexpr static uint32_t kb2 = kb1*ka+kb;
+ constexpr static uint32_t ka3 = ka2*ka;
+ constexpr static uint32_t kb3 = kb2*ka+kb;
+ __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3);
+ __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3);
+ const __m256i mask1 = _mm256_set1_epi32(kmask);
+ const __m256i mask2 = _mm256_set1_epi32(km32);
+
+ inline __m256i next8(uint32_t val1, uint32_t val2) {
+ __m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
+ __m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
+ return _mm256_and_si256(mres, _mm256_set1_epi32(kmask)) ^ _mm256_set1_epi32(km32);
+ }
+};
+
+template <int nrc_y>
+static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Trellis1 trellis;
+
+ auto shifts = _mm_set_epi32(0, 0, 4, 0);
+ auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
+
+ union { __m256 vec; float val[8]; } s_helper;
+
+ constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+ __m256 accd[k_acc];
+ const float * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ const float * dptr = (const float *)((const char*)vx + ix*bx);
+ const float d = *dptr * 31.75f * 1.05f;
+ const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
+
+ for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+ s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+ s8 = _mm_shuffle_epi8(values, s8);
+ auto s32 = _mm256_cvtepi8_epi32(s8);
+ s_helper.vec = _mm256_cvtepi32_ps(s32);
+ for (int ib = 0; ib < QK_K/64; ++ib) {
+ auto scale1 = _mm256_set1_ps(s_helper.val[2*ib+0]);
+ auto scale2 = _mm256_set1_ps(s_helper.val[2*ib+1]);
+ for (int j = 0; j < 4; ++j) {
+ auto xval1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(ql[8*ib+j+0]+4096)));
+ auto xval2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(ql[8*ib+j+4]+4096)));
+ if constexpr (nrc_y == 1) {
+ accd[0] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 64*ib + 8*j + 0), xval1, accd[0]);
+ accd[1] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 64*ib + 8*j + 32), xval2, accd[1]);
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 64*ib + 8*j + 0), xval1, accd[iy]);
+ accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 64*ib + 8*j + 32), xval2, accd[iy]);
+ }
+ }
+ }
+ }
+ }
+
+ if constexpr (nrc_y == 1) {
+ __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_add_ps(accd[0], accd[1]));
+ info.store(ix, 0, hsum_float_8(res));
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
+ info.store(ix, iy, hsum_float_8(res));
+ }
+ }
+ }
+}
+
+static inline __m256 abs_ps(__m256 vals) {
+ // Clear sign-bit of all the 32-bit floats in vals
+ __m256 sign_bit = _mm256_set1_ps(-0.0f);
+ return _mm256_andnot_ps(sign_bit, vals);
+}
+
+// Negates 32-bit float lanes of an 8x32-bit vector
+// based on 8x8-bit condition var. For float lane i, if byte i of
+// `condition` is nonzero, the float will be negated.
+static inline __m256 conditional_negate_ps(__m256 vals, uint64_t condition_mask_u64) {
+ __m128i condition_bytes = _mm_set_epi64x(0, condition_mask_u64);
+ // Make `should_negate_byte_mask` where byte i == 0xFF if byte i in condition_bytes is zero,
+ // else 0x00 (upper bytes are meaningless)
+ __m128i zeros = _mm_setzero_si128();
+ __m128i is_zero_byte_mask = _mm_cmpeq_epi8(condition_bytes, zeros);
+ __m128i should_negate_byte_mask = _mm_cmpeq_epi8(is_zero_byte_mask, zeros);
+ // Widen lower 8x8 bits of `should_negate_byte_mask` to 8x32 bits by padding zeros
+ // expanded_mask_epi32[j] will be 0x000000FF if vals[j] should be negated, zero otherwise
+ __m256i expanded_mask_epi32 = _mm256_cvtepu8_epi32(should_negate_byte_mask);
+ // Same as above but with all 32 bits of lane j set if vals[j] should be negated (use to make XOR mask)
+ __m256i full_dword_negate_mask = _mm256_cmpgt_epi32(expanded_mask_epi32, _mm256_setzero_si256());
+ // Negate via XOR on sign bits of each 32-bit float
+ __m256i sign_bit_pattern = _mm256_set1_epi32(0x80000000); // MSB set for a 32-bit value
+ __m256i xor_mask_epi32 = _mm256_and_si256(full_dword_negate_mask, sign_bit_pattern);
+ __m256 xor_mask_ps = _mm256_castsi256_ps(xor_mask_epi32);
+ return _mm256_xor_ps(vals, xor_mask_ps);
+}
+
+template <int nrc_y>
+static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Trellis1 trellis;
+
+ __m256 accd[nrc_y];
+ const float * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ const float * dptr = (const float *)((const char*)vx + ix*bx);
+ const float d = *dptr * 31.75f * 1.015f;
+ const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ const uint8_t * qh = x[i].qh;
+ for (int j = 0; j < 128; j+=8) {
+ uint64_t mask1 = 0x0101010101010101 << (j/32);
+ uint64_t mask2 = mask1 << 4;
+ uint32_t val1 = ql[j/8] + 4096;
+ uint32_t val2 = ql[j/8+16] + 4096;
+ const uint64_t signs = *((const uint64_t *)(qh + (j%32)));
+ const float x_scale1 = (x[i].scales[j/32] & 0xf);
+ const float x_scale2 = (x[i].scales[j/32] >> 4);
+ const __m256 x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
+ const __m256 x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ accd[iy] = _mm256_fmadd_ps(
+ conditional_negate_ps(
+ _mm256_load_ps(y[iy] + i*QK_K+j), signs & mask1
+ ),
+ _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
+ accd[iy]
+ );
+ accd[iy] = _mm256_fmadd_ps(
+ conditional_negate_ps(
+ _mm256_load_ps(y[iy] + i*QK_K+j+128), signs & mask2
+ ),
+ _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
+ accd[iy]
+ );
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
+ info.store(ix, iy, hsum_float_8(res));
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+ constexpr int kNumGroups = 64;
+
+ Trellis2 trellis;
+
+ __m256 accd[nrc_y];
+ __m256 accd2[nrc_y];
+ const float * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ const float * dptr = (const float *)((const char*)vx + ix*bx);
+ const float d = dptr[0] * 31.75f * 1.01f;
+ const float row_av = dptr[1];
+ const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ accd[iy] = _mm256_setzero_ps();
+ accd2[iy] = _mm256_setzero_ps();
+ }
+
+ for (int i = 0; i < nb; ++i) {
+ const uint32_t * shb = x[i].qs;
+ const uint8_t * ql = (const uint8_t *)(shb + 8);
+ const uint8_t * qh = ql + kNumGroups;
+ for (int j = 0; j < 128; j+=8) {
+ const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15);
+ const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15);
+ const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64;
+ const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64;
+ const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
+ const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
+ uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
+ uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
+ uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
+ uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
+ const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3));
+ const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ accd[iy] = _mm256_fmadd_ps(
+ _mm256_load_ps(y[iy] + i*QK_K+j),
+ _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
+ accd[iy]
+ );
+ accd[iy] = _mm256_fmadd_ps(
+ _mm256_load_ps(y[iy] + i*QK_K+j+128),
+ _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
+ accd[iy]
+ );
+ accd2[iy] = _mm256_add_ps(
+ _mm256_load_ps(y[iy] + i*QK_K+j),
+ accd2[iy]
+ );
+ accd2[iy] = _mm256_add_ps(
+ _mm256_load_ps(y[iy] + i*QK_K+j+128),
+ accd2[iy]
+ );
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
+ __m256 res2 = _mm256_mul_ps(_mm256_set1_ps(row_av), accd2[iy]);
+ info.store(ix, iy, hsum_float_8(res) + hsum_float_8(res2));
+ }
+ }
+}
+
+} // namespace
+
+bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
+
+ if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F32) {
+ return false;
+ }
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_IQ2_KT:
+ assert (ne00 % QK_K == 0);
+ kernels[0] = mul_mat_iq2_kt_F32_T<1>;
+ kernels[1] = mul_mat_iq2_kt_F32_T<2>;
+ kernels[2] = mul_mat_iq2_kt_F32_T<3>;
+ kernels[3] = mul_mat_iq2_kt_F32_T<4>;
+ kernels[4] = mul_mat_iq2_kt_F32_T<5>;
+ kernels[5] = mul_mat_iq2_kt_F32_T<6>;
+ kernels[6] = mul_mat_iq2_kt_F32_T<7>;
+ kernels[7] = mul_mat_iq2_kt_F32_T<8>;
+ break;
+ case GGML_TYPE_IQ3_KT:
+ assert (ne00 % QK_K == 0);
+ kernels[0] = mul_mat_iq3_kt_F32_T<1>;
+ kernels[1] = mul_mat_iq3_kt_F32_T<2>;
+ kernels[2] = mul_mat_iq3_kt_F32_T<3>;
+ kernels[3] = mul_mat_iq3_kt_F32_T<4>;
+ kernels[4] = mul_mat_iq3_kt_F32_T<5>;
+ kernels[5] = mul_mat_iq3_kt_F32_T<6>;
+ kernels[6] = mul_mat_iq3_kt_F32_T<7>;
+ kernels[7] = mul_mat_iq3_kt_F32_T<8>;
+ break;
+ case GGML_TYPE_IQ4_KT:
+ assert (ne00 % QK_K == 0);
+ kernels[0] = mul_mat_iq4_kt_F32_T<1>;
+ kernels[1] = mul_mat_iq4_kt_F32_T<2>;
+ kernels[2] = mul_mat_iq4_kt_F32_T<3>;
+ kernels[3] = mul_mat_iq4_kt_F32_T<4>;
+ kernels[4] = mul_mat_iq4_kt_F32_T<5>;
+ kernels[5] = mul_mat_iq4_kt_F32_T<6>;
+ kernels[6] = mul_mat_iq4_kt_F32_T<7>;
+ kernels[7] = mul_mat_iq4_kt_F32_T<8>;
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+
+}
+
+#else // !__x86_64__
+
+bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
+ return false;
+}
+
+#endif
+
+#endif
\ No newline at end of file
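The long list of ka1/kb1 ... ka7/kb7 constants in Trellis1 and trellis_next8() above relies on the fact that applying the affine map x -> ka*x + kb k+1 times is again an affine map whose coefficients can be precomputed (mod 2^32), so eight consecutive trellis states come from one seed with a single vector multiply-add (_mm256_mullo_epi32 + _mm256_add_epi32). The scalar sketch below (illustrative, no intrinsics) checks that derivation against the step-by-step recurrence.

#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t ka = 89226354, kb = 64248484;

    // Precompute the composed coefficients exactly as Trellis1 does:
    // kas[k] = kas[k-1]*ka and kbs[k] = kbs[k-1]*ka + kb (mod 2^32);
    // kas[1..7]/kbs[1..7] correspond to ka1..ka7 / kb1..kb7 above.
    uint32_t kas[8], kbs[8];
    kas[0] = ka; kbs[0] = kb;
    for (int k = 1; k < 8; ++k) {
        kas[k] = kas[k-1]*ka;
        kbs[k] = kbs[k-1]*ka + kb;
    }

    uint32_t seed = 4096 + 777;   // 777 stands in for a packed group index
    uint32_t step = seed;
    for (int k = 0; k < 8; ++k) {
        step = ka*step + kb;                        // k+1 sequential LCG steps
        const uint32_t lane = kas[k]*seed + kbs[k]; // what next8() computes in lane k
        printf("k=%d  sequential=%08x  batched=%08x  %s\n",
               k, (unsigned)step, (unsigned)lane, step == lane ? "ok" : "MISMATCH");
    }
    return 0;
}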
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.h b/ggml/src/iqk/iqk_gemm_ktquants.h
new file mode 100644
index 00000000..b1e84d63
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_ktquants.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "iqk_common.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include <array>
+
+bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
+
+#endif
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index abf14ed0..43be0885 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -22,6 +22,7 @@
#include "iqk_flash_impl.h"
#include "iqk_gemm_floats.h"
#include "iqk_gemm_kquants.h"
+#include "iqk_gemm_ktquants.h"
#include "iqk_gemm_iquants.h"
#include "iqk_gemm_iqk_quants.h"
#include "iqk_gemm_1bit.h"
@@ -541,6 +542,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ5_KS_R4:
return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
+ return ggml_type(typeB) == GGML_TYPE_F32 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -921,4 +926,4 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*n
return false;
}
-#endif
+#endif
\ No newline at end of file
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 93aa2180..c1f7a8e4 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -22,6 +22,8 @@
#include <algorithm>
#include <cstring>
#include <mutex>
+#include <random>
+#include <memory>
#include <thread>
#include <atomic>
#include <unordered_map>
@@ -7408,3 +7410,1387 @@ void dequantize_row_ms_i2s(const void * vx, float * y, int64_t k) {
}
}
+namespace {
+#ifdef __AVX2__
+__m128 hsum_float_4x4(__m128 * accm) {
+ accm[0] = _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[2]), _mm_unpackhi_ps(accm[0], accm[2]));
+ accm[1] = _mm_add_ps(_mm_unpacklo_ps(accm[1], accm[3]), _mm_unpackhi_ps(accm[1], accm[3]));
+ return _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[1]), _mm_unpackhi_ps(accm[0], accm[1]));
+}
+__m256 hsum_float_8x8(__m256 * accm) {
+ for (int i = 0; i < 4; ++i) {
+ accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
+ _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
+ }
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+ return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+}
+__m256 hsum_float_4x8(__m256 * accm) {
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+ return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+}
+#endif
+template <int block_size, int group_size, int num_bits, bool is_abs = false>
+class QuantizerIQKT {
+ static_assert(group_size == 8 || group_size == 4);
+ static_assert(block_size >= 8 && block_size%8 == 0);
+public:
+ constexpr static int kSuperBlockSize = QK_K;
+ constexpr static int kBlockSize = block_size;
+ constexpr static int kGroupSize = group_size;
+ constexpr static int kNg = kBlockSize/kGroupSize;
+ constexpr static int kNblock = kSuperBlockSize/kBlockSize;
+ constexpr static int kNumVal = 1 << num_bits; // i.e., 16 bits per group of 8
+ constexpr static float kScale = 31.75f;
+ constexpr static bool kVerbose = false;
+
+ QuantizerIQKT(int num_clusters, int num_neighbours, int offset = 4096);
+ const float * values() const { return m_values.data(); }
+
+ inline void find_best_match(float d, const float * xb, const float * weight, int * best_idx) const;
+ inline std::pair<float, float> find_best_scale(const float * xb, const float * weight, const int * best_idx) const;
+ inline float find_best_inverse_scale(const float * xb, const float * weight, const int * best_idx) const;
+
+ static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) {
+ constexpr uint32_t ka = 89226354;
+ constexpr uint32_t kb = 64248484;
+ constexpr uint32_t kmask = 0x8fff8fff;
+ constexpr uint32_t km32 = 0x3b603b60;
+ uint32_t x = i + offset;
+ for (int k = 0; k < kGroupSize; ++k) {
+ x = ka*x + kb;
+ uint32_t s = (x & kmask) ^ km32;
+ float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16);
+ if constexpr (is_abs) result[k] = scale*std::abs(val);
+ else result[k] = scale*val;
+ }
+ }
+
+ static inline int bin4(float x) {
+ if constexpr (is_abs) {
+ return x < 16.f ? 0 : x < 32.f ? 1 : x < 64.f ? 2 : 3;
+ } else {
+ return x < -24.f ? 0 : x < 0.0f ? 1 : x < 24.f ? 2 : 3;
+ }
+ }
+ static inline int bin5(float x) {
+ if constexpr (is_abs) {
+ return x < 11.2f ? 0 : x < 24.f ? 1 : x < 39.f ? 2 : x < 58.f ? 3 : 4;
+ } else {
+ return x < -48.f ? 0 : x < -16.f ? 1 : x < 16.f ? 2 : x < 48.f ? 3 : 4;
+ }
+ }
+ inline int bin3(int idim, float x) const { return x < m_mid[2*idim+0] ? 0 : x < m_mid[2*idim+1] ? 1 : 2; }
+
+ static inline void set_weights(float sigma2_scale, int nblock, const float * x, const float * imatrix, float * row_weights) {
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ const float * xbl = x + ibl*kSuperBlockSize;
+ float * wbl = row_weights + ibl*kSuperBlockSize;
+
+ float sumx2 = 0;
+ for (int j = 0; j < kSuperBlockSize; ++j) sumx2 += xbl[j]*xbl[j];
+ const float sigma2 = sigma2_scale*sumx2/kSuperBlockSize;
+
+ if (imatrix) {
+ const float * qw = imatrix + ibl*kSuperBlockSize;
+ for (int j = 0; j < kSuperBlockSize; ++j) wbl[j] = qw[j] * sqrtf(sigma2 + xbl[j]*xbl[j]);
+ } else {
+ for (int j = 0; j < kSuperBlockSize; ++j) wbl[j] = 0.25f*sigma2 + xbl[j]*xbl[j];
+ }
+ }
+ }
+private:
+ static std::vector<float> cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid);
+ static std::vector<std::vector<int>> finalize_clusters(int num_neighbours, const std::vector<float>& points, const std::vector<float>& clusters,
+ std::vector<std::vector<float>>& c_values);
+ std::vector<float> m_values;
+ std::vector<float> m_clusters;
+ std::vector<std::vector<int>> m_in_cluster;
+ std::vector<std::vector<float>> m_c_values;
+ float m_mid[4*kGroupSize];
+};
+
+template <int block_size, int group_size, int num_bits, bool is_abs>
+QuantizerIQKT<block_size, group_size, num_bits, is_abs>::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) {
+ m_values.resize(kNumVal*kGroupSize);
+ float * data = m_values.data();
+ for (int i = 0; i < kNumVal; ++i) {
+ set_values(i, data, kScale, offset);
+ data += kGroupSize;
+ }
+ // Make 128 clusters.
+ // Note: we get a slightly better result by using 64 clusters
+ // at the expense of almost doubling the quantization time.
+ m_clusters = cluster_points(m_values, num_clusters, 200, m_mid);
+ GGML_ASSERT(!m_clusters.empty());
+ m_in_cluster = finalize_clusters(num_neighbours, m_values, m_clusters, m_c_values);
+}
+
+template <int block_size, int group_size, int num_bits, bool is_abs>
+std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_scale(
+ const float * xb, const float * weight, const int * best_idx) const {
+ float sumqx = 0, sumq2 = 0;
+#ifdef __AVX2__
+ auto vqx = _mm256_setzero_ps();
+ auto vq2 = _mm256_setzero_ps();
+ for (int l = 0; l < kBlockSize; l += 8) {
+ auto vx = _mm256_loadu_ps(xb+l);
+ auto vw = _mm256_loadu_ps(weight+l);
+ auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) :
+ _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]),
+ _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0]));
+ auto vqw = _mm256_mul_ps(vq, vw);
+ vqx = _mm256_fmadd_ps(vqw, vx, vqx);
+ vq2 = _mm256_fmadd_ps(vqw, vq, vq2);
+ }
+ sumqx = hsum_float_8(vqx);
+ sumq2 = hsum_float_8(vq2);
+#else
+ for (int l = 0; l < kNg; ++l) {
+ auto xl = xb + kGroupSize*l;
+ auto wl = weight + kGroupSize*l;
+ auto ql = m_values.data() + kGroupSize*best_idx[l];
+ for (int k = 0; k < kGroupSize; ++k) {
+ sumqx += wl[k]*ql[k]*xl[k];
+ sumq2 += wl[k]*ql[k]*ql[k];
+ }
+ }
+#endif
+ return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f);
+}
+
+template <int block_size, int group_size, int num_bits, bool is_abs>
+float QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_inverse_scale(
+ const float * xb, const float * weight, const int * best_idx) const {
+ float sumqx = 0, sumx2 = 0;
+#ifdef __AVX2__
+ auto vqx = _mm256_setzero_ps();
+ auto vx2 = _mm256_setzero_ps();
+ for (int l = 0; l < kBlockSize; l += 8) {
+ auto vx = _mm256_loadu_ps(xb+l);
+ auto vw = _mm256_loadu_ps(weight+l);
+ auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) :
+ _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]),
+ _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0]));
+ auto vxw = _mm256_mul_ps(vx, vw);
+ vx2 = _mm256_fmadd_ps(vxw, vx, vx2);
+ vqx = _mm256_fmadd_ps(vxw, vq, vqx);
+ }
+ sumqx = hsum_float_8(vqx);
+ sumx2 = hsum_float_8(vx2);
+#else
+ for (int l = 0; l < kNg; ++l) {
+ auto xl = xb + kGroupSize*l;
+ auto wl = weight + kGroupSize*l;
+ auto ql = m_values.data() + kGroupSize*best_idx[l];
+ for (int k = 0; k < kGroupSize; ++k) {
+ sumqx += wl[k]*ql[k]*xl[k];
+ sumx2 += wl[k]*xl[k]*xl[k];
+ }
+ }
+#endif
+ return sumx2 > 0 ? sumqx/sumx2 : 0.f;
+}
+
+template <int block_size, int group_size, int num_bits, bool is_abs>
+void QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const {
+ if (!d) {
+ std::memset(best_idx, 0, kNg*sizeof(int));
+ return;
+ }
+ int ncluster = m_clusters.size()/kGroupSize;
+ float id = 1/d;
+#ifdef __AVX2__
+ if constexpr (kGroupSize == 8) {
+ __m256 sqx[8];
+ const __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+ float sx[8];
+ int index[8];
+ auto vid = _mm256_set1_ps(id);
+ auto add8 = _mm256_set1_epi32(8);
+ for (int l = 0; l < kNg; ++l) {
+ auto xl = xb + 8*l;
+ auto wl = weight + 8*l;
+ auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xl));
+ auto vw = _mm256_loadu_ps(wl);
+ int jbest = -1;
+ if (kGroupSize == 8 && (ncluster == 256 || ncluster == 6561)) {
+ _mm256_store_ps(sx, vx);
+ uint16_t u = 0;
+ if (ncluster == 256) {
+ for (int j = 0; j < 8; ++j) if (sx[j] > m_mid[j]) u |= (1 << j);
+ } else {
+ int s = 1;
+ for (int j = 0; j < 8; ++j) { u += s*bin3(j, sx[j]); s *= 3; }
+ }
+ jbest = u;
+ } else {
+ auto vbest = _mm256_set1_ps(INFINITY);
+ auto best_index = _mm256_set1_epi32(-1);
+ float best = INFINITY;
+ auto idx = add_idx;
+ for (int j = 0; j < ncluster; j += 8) {
+ for (int i = 0; i < 8; ++i) {
+ auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+i));
+ auto vdiff = _mm256_sub_ps(vq, vx);
+ sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
+ }
+ auto score = hsum_float_8x8(sqx);
+ auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
+ best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+ _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+ vbest = _mm256_min_ps(vbest, score);
+ idx = _mm256_add_epi32(idx, add8);
+ }
+ _mm256_store_ps(sx, vbest);
+ _mm256_store_si256((__m256i *)index, best_index);
+ for (int i = 0; i < 8; ++i) {
+ if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
+ }
+ }
+ auto& points = m_in_cluster[jbest];
+ auto& values = points.empty() ? m_values : m_c_values[jbest];
+ int npoint = values.size()/kGroupSize;
+ GGML_ASSERT(npoint > 0 && npoint%8 == 0);
+ int jbest_cluster = jbest;
+ auto vbest = _mm256_set1_ps(INFINITY);
+ auto best_index = _mm256_set1_epi32(-1);
+ auto best = INFINITY; jbest = -1;
+ auto idx = add_idx;
+ for (int j = 0; j < npoint; j += 8) {
+ for (int i = 0; i < 8; ++i) {
+ auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+i));
+ auto vdiff = _mm256_sub_ps(vq, vx);
+ sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
+ }
+ auto score = hsum_float_8x8(sqx);
+ auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
+ best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+ _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+ vbest = _mm256_min_ps(vbest, score);
+ idx = _mm256_add_epi32(idx, add8);
+ }
+ _mm256_store_ps(sx, vbest);
+ _mm256_store_si256((__m256i *)index, best_index);
+ for (int i = 0; i < 8; ++i) {
+ if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
+ }
+ if (jbest < 0) {
+ fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
+ GGML_ASSERT(false);
+ }
+ best_idx[l] = points.empty() ? jbest : points[jbest];
+ }
+ } else {
+ __m256 sqx[4];
+ const __m256i add_idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+ const __m256 sign_bit = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff));
+ float sx[8];
+ int index[8];
+ auto vid_p = _mm256_set1_ps(id);
+ auto add8 = _mm256_set1_epi32(8);
+ for (int l = 0; l < kNg; ++l) {
+ auto xl = xb + 4*l;
+ auto wl = weight + 4*l;
+ auto vx4 = _mm_loadu_ps(xl);
+ auto vx = _mm256_mul_ps(vid_p, _mm256_set_m128(vx4, vx4));
+ auto vw4 = _mm_loadu_ps(wl);
+ auto vw = _mm256_set_m128(vw4, vw4);
+ int jbest = -1;
+ if (ncluster == 256 || ncluster == 625) {
+ _mm256_storeu_ps(sx, vx);
+ uint16_t u = 0;
+ if (ncluster == 256) {
+ for (int k = 0; k < 4; ++k) u |= (bin4(sx[k]) << 2*k);
+ } else {
+ int s = 1;
+ for (int k = 0; k < 4; ++k) { u += bin5(sx[k])*s; s *= 5; }
+ }
+ jbest = u;
+ } else {
+ auto vbest = _mm256_set1_ps(INFINITY);
+ auto best_index = _mm256_set1_epi32(-1);
+ float best = INFINITY;
+ auto idx = add_idx;
+ for (int j = 0; j < ncluster; j += 8) {
+ for (int i = 0; i < 4; ++i) {
+ auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+2*i));
+ auto vdiff = _mm256_sub_ps(vq, vx);
+ vdiff = _mm256_and_ps(sign_bit, vdiff);
+ sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff)));
+ }
+ auto score = hsum_float_4x8(sqx);
+ auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
+ best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+ _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+ vbest = _mm256_min_ps(vbest, score);
+ idx = _mm256_add_epi32(idx, add8);
+ }
+ _mm256_store_ps(sx, vbest);
+ _mm256_store_si256((__m256i *)index, best_index);
+ for (int i = 0; i < 8; ++i) {
+ if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
+ }
+ }
+ auto& points = m_in_cluster[jbest];
+ auto& values = m_c_values[jbest];
+ GGML_ASSERT(!points.empty() && points.size()%8 == 0);
+ int jbest_cluster = jbest;
+ auto vbest = _mm256_set1_ps(INFINITY);
+ auto best_index = _mm256_set1_epi32(-1);
+ float best = INFINITY; jbest = -1;
+ auto idx = add_idx;
+ for (int j = 0; j < int(points.size()); j += 8) {
+ for (int i = 0; i < 4; ++i) {
+ auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+2*i));
+ auto vdiff = _mm256_sub_ps(vq, vx);
+ sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
+ }
+ auto score = hsum_float_4x8(sqx);
+ auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
+ best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+ _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+ vbest = _mm256_min_ps(vbest, score);
+ idx = _mm256_add_epi32(idx, add8);
+ }
+ _mm256_store_ps(sx, vbest);
+ _mm256_store_si256((__m256i *)index, best_index);
+ for (int i = 0; i < 8; ++i) {
+ if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
+ }
+ if (jbest < 0) {
+ fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
+ GGML_ASSERT(false);
+ }
+ best_idx[l] = points[jbest];
+ }
+ }
+#else
+ // TODO
+ std::memset(best_idx, 0, kNg*sizeof(int));
+#endif
+}
+
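+// Assigns every codebook point to its num_neighbours nearest clusters, pads each cluster's
+// point list to a multiple of 8 (as required by the SIMD search in find_best_match) with the
+// nearest points not already assigned to it, and gathers the coordinates of each cluster's
+// points into c_values. Returns the per-cluster lists of point indices.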
+template <int block_size, int group_size, int num_bits, bool is_abs>
+std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::finalize_clusters(int num_neighbours,
+ const std::vector<float>& values, const std::vector<float>& clusters, std::vector<std::vector<float>>& c_values) {
+ int ncluster = clusters.size()/kGroupSize;
+ std::vector<std::vector<int>> p_in_cluster(ncluster);
+ std::vector<int> which_cluster(num_neighbours*kNumVal);
+ std::vector<int> ibest(num_neighbours);
+ std::vector<float> best(num_neighbours);
+ for (int ip = 0; ip < kNumVal; ++ip) {
+ auto vp = values.data() + ip*kGroupSize;
+ for (int j = 0; j < num_neighbours; ++j) {
+ best[j] = INFINITY; ibest[j] = -1;
+ }
+ for (int ic = 0; ic < ncluster; ++ic) {
+ auto vc = clusters.data() + ic*kGroupSize;
+ float dist2 = 0;
+ for (int k = 0; k < kGroupSize; ++k) {
+ float d = vp[k] - vc[k]; dist2 += d*d;
+ }
+ for (int j = 0; j < num_neighbours; ++j) {
+ if (dist2 < best[j]) {
+ for (int k = num_neighbours-1; k > j; --k) {
+ best[k] = best[k-1]; ibest[k] = ibest[k-1];
+ }
+ best[j] = dist2; ibest[j] = ic;
+ break;
+ }
+ }
+ }
+ for (int j = 0; j < num_neighbours; ++j) {
+ if (ibest[j] < 0) {
+ printf("Oops: ibest[%d] = %d\n", j, ibest[j]);
+ }
+ GGML_ASSERT(ibest[j] >= 0);
+ p_in_cluster[ibest[j]].push_back(ip);
+ }
+ std::memcpy(which_cluster.data() + num_neighbours*ip, ibest.data(), num_neighbours*sizeof(int));
+ }
+ std::vector<std::pair<float, int>> extra;
+ extra.reserve(kNumVal);
+ for (int ic = 0; ic < ncluster; ++ic) {
+ auto& points = p_in_cluster[ic];
+ if (!points.empty() && points.size()%8 == 0) continue;
+ extra.clear();
+ auto vc = clusters.data() + ic*kGroupSize;
+ for (int ip = 0; ip < kNumVal; ++ip) {
+ bool can_add = true;
+ for (int j = 0; j < num_neighbours; ++j) {
+ if (which_cluster[num_neighbours*ip+j] == ic) { can_add = false; break; }
+ }
+ if (!can_add) continue;
+ auto vp = values.data() + ip*kGroupSize;
+ float dist2 = 0;
+ for (int k = 0; k < kGroupSize; ++k) {
+ float d = vp[k] - vc[k]; dist2 += d*d;
+ }
+ extra.push_back(std::make_pair(dist2, ip));
+ }
+ std::sort(extra.begin(), extra.end());
+ int nadd = 8*((points.size()+7)/8) - points.size();
+ for (int i = 0; i < nadd; ++i) points.push_back(extra[i].second);
+ GGML_ASSERT(points.size()%8 == 0);
+ }
+ auto min = p_in_cluster.front().size(), max = p_in_cluster.front().size();
+ for (auto& points : p_in_cluster) {
+ min = std::min(min, points.size());
+ max = std::max(max, points.size());
+ }
+ c_values.resize(p_in_cluster.size());
+ for (int i = 0; i < int(p_in_cluster.size()); ++i) {
+ auto& points = p_in_cluster[i];
+ c_values[i].resize(points.size()*kGroupSize);
+ auto ptr = c_values[i].data();
+ for (auto j : points) {
+ std::memcpy(ptr, values.data() + j*kGroupSize, kGroupSize*sizeof(float));
+ ptr += kGroupSize;
+ }
+ }
+
+ if (kVerbose) {
+ printf("%s: prepared %d clusters\n", __func__, ncluster);
+ printf(" min number of points in a cluster: %d\n", int(min));
+ printf(" max number of points in a cluster: %d\n", int(max));
+ }
+ return p_in_cluster;
+}
+
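+// Clusters the codebook points into ncluster centers in kGroupSize dimensions and fills mid[]
+// with the per-coordinate split points used for direct binning. For the cluster counts used
+// by the kt types the centers are obtained directly from per-coordinate bins; otherwise a
+// standard k-means iteration is run.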
+template <int block_size, int group_size, int num_bits, bool is_abs>
+std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid) {
+ constexpr int ndim = kGroupSize;
+ GGML_ASSERT(points.size() % ndim == 0);
+ int npoint = points.size() / ndim;
+ GGML_ASSERT(npoint >= 2*ncluster);
+ std::vector<std::pair<float, float>> range(ndim, std::make_pair(INFINITY, -INFINITY));
+ double Fo = 0;
+ for (int i = 0; i < npoint; ++i) {
+ auto v = points.data() + i*ndim;
+ for (int k = 0; k < ndim; ++k) {
+ Fo += v[k]*v[k];
+ range[k].first = std::min(range[k].first, v[k]);
+ range[k].second = std::max(range[k].second, v[k]);
+ }
+ }
+ if (kVerbose) printf("%s (ndim = %d, npoint = %d): Fo = %g\n", __func__, ndim, npoint, Fo/points.size());
+ if constexpr (is_abs) {
+ std::vector<int> P(npoint);
+ for (int idim = 0; idim < ndim; ++idim) {
+ for (int ip = 0; ip < npoint; ++ip) P[ip] = points[ip*ndim+idim];
+ std::sort(P.begin(), P.end());
+ if (ndim == 8 && ncluster == 6561) {
+ mid[2*idim + 0] = P[npoint/3];
+ mid[2*idim + 1] = P[2*npoint/3];
+ } else {
+ mid[idim] = npoint%2 == 0 ? 0.5f*(P[npoint/2] + P[npoint/2-1]) : P[npoint/2];
+ if (kVerbose) printf("%s: mid[%d] = %g\n", __func__, idim, mid[idim]);
+ }
+ }
+ } else {
+ for (int k = 0; k < ndim; ++k) mid[k] = 0.5f*(range[k].first + range[k].second);
+ }
+ std::vector<float> sump(ncluster*ndim);
+ std::vector<int> counts(ncluster);
+ std::vector<float> result(ncluster*ndim);
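+ // Special case: when the cluster count is bins^ndim (2^8 or 3^8 in 8 dimensions, 4^4 or
+ // 5^4 in 4 dimensions), the cluster of a point follows directly from binning each
+ // coordinate, so the centers are just the per-bin averages and no iterations are needed.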
+ if (ndim == 8 && (ncluster == 256 || ncluster == 6561)) {
+ std::memset(sump.data(), 0, sump.size()*sizeof(float));
+ std::memset(counts.data(), 0, counts.size()*sizeof(int));
+ for (int ip = 0; ip < npoint; ++ip) {
+ auto vp = points.data() + ndim*ip;
+ uint16_t u = 0;
+ if (ncluster == 256) {
+ for (int k = 0; k < ndim; ++k) if (vp[k] > mid[k]) u |= (1 << k);
+ } else {
+ int s = 1;
+ for (int k = 0; k < ndim; ++k) {
+ int bin = vp[k] < mid[2*k+0] ? 0 : vp[k] < mid[2*k+1] ? 1 : 2;
+ u += s*bin; s *= 3;
+ }
+ }
+ ++counts[u];
+ for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k];
+ }
+ for (int ic = 0; ic < ncluster; ++ic) {
+ if (!counts[ic]) {
+ printf("%s: Oops. Cluster %d has no points\n", __func__, ic);
+ GGML_ABORT("fatal error");
+ }
+ for (int k = 0; k < ndim; ++k) result[ic*ndim + k] = sump[ic*ndim + k]/counts[ic];
+ }
+ return result;
+ }
+ else if (ndim == 4 && (ncluster == 256 || ncluster == 625)) {
+ std::memset(sump.data(), 0, sump.size()*sizeof(float));
+ std::memset(counts.data(), 0, counts.size()*sizeof(int));
+ for (int ip = 0; ip < npoint; ++ip) {
+ auto vp = points.data() + ndim*ip;
+ uint16_t u = 0;
+ if (ncluster == 256) {
+ for (int k = 0; k < ndim; ++k) u |= (bin4(vp[k]) << 2*k);
+ } else {
+ int s = 1;
+ for (int k = 0; k < ndim; ++k) { u += s*bin5(vp[k]); s *= 5; }
+ }
+ if (u >= int(counts.size())) {
+ printf("Oops: u = %u, vp = %g, %g, %g, %g\n", u, vp[0], vp[1], vp[2], vp[3]);
+ u = 0;
+ if (ncluster == 256) {
+ for (int k = 0; k < ndim; ++k) {
+ auto bin = bin4(vp[k]); u |= (bin << 2*k);
+ printf(" bin[%d] = %d, u = %u", k, bin, u);
+ }
+ } else {
+ for (int k = 0; k < ndim; ++k) printf(" bin[%d] = %d", k, bin5(vp[k]));
+ }
+ printf("\n");
+ GGML_ABORT("fatal error");
+ }
+ ++counts[u];
+ for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k];
+ }
+ int nzero = 0;
+ for (int ic = 0; ic < ncluster; ++ic) {
+ if (!counts[ic]) {
+ ++nzero;
+ printf("%s: Oops. Cluster %d has no points: ", __func__, ic);
+ for (int k = 0; k < ndim; ++k) {
+ int l = (ic >> 2*k) & 3;
+ printf(" %d", l);
+ }
+ printf("\n");
+ } else {
+ for (int k = 0; k < ndim; ++k) result[ic*ndim + k] = sump[ic*ndim + k]/counts[ic];
+ }
+ }
+ if (nzero > 0) printf("%s: %d out of %d clusters did not have any points\n", __func__, nzero, ncluster);
+ return result;
+ }
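+ // General case: random initialization inside the per-coordinate range, followed by
+ // standard k-means (Lloyd) iterations until no point changes cluster or the objective
+ // stops improving.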
+ std::mt19937 rndm(1234);
+ float scale = 1.f/4294967296.f;
+ for (int i = 0; i < ncluster; ++i) {
+ auto v = result.data() + i*ndim;
+ for (int k = 0; k < ndim; ++k) v[k] = range[k].first + (range[k].second - range[k].first)*scale*rndm();
+ }
+ std::vector<int> which_cluster(npoint, -1);
+ double Flast = Fo;
+ for (int iter = 0; iter < niter; ++iter) {
+ std::memset(sump.data(), 0, sump.size()*sizeof(float));
+ std::memset(counts.data(), 0, counts.size()*sizeof(int));
+ int nchanged = 0;
+ double F = 0;
+ for (int ip = 0; ip < npoint; ++ip) {
+ auto vp = points.data() + ndim*ip;
+ float best = INFINITY; int ibest = -1;
+ for (int ic = 0; ic < ncluster; ++ic) {
+ auto vc = result.data() + ndim*ic;
+ float dist2 = 0;
+ for (int k = 0; k < ndim; ++k) {
+ float d = vp[k] - vc[k]; dist2 += d*d;
+ }
+ if (dist2 < best) {
+ best = dist2; ibest = ic;
+ }
+ }
+ if (ibest < 0) {
+ printf("Oops(iteration %d) - failed to find cluster for point", iter);
+ for (int k = 0; k < ndim; ++k) printf(" %g", vp[k]);
+ printf("\nHave %d clusters\n", ncluster);
+ }
+ GGML_ASSERT(ibest >= 0);
+ F += best;
+ if (which_cluster[ip] != ibest) ++nchanged;
+ which_cluster[ip] = ibest;
+ ++counts[ibest];
+ auto vc = sump.data() + ndim*ibest;
+ for (int k = 0; k < ndim; ++k) vc[k] += vp[k];
+ }
+ if (nchanged == 0) break;
+ for (int ic = 0; ic < ncluster; ++ic) {
+ float norm = counts[ic] > 0 ? 1.f/counts[ic] : 0.f;
+ auto vc = sump.data() + ndim*ic;
+ auto r = result.data() + ndim*ic;
+ for (int k = 0; k < ndim; ++k) r[k] = vc[k]*norm;
+ }
+ if (kVerbose) printf("%s(iteration %d): F = %g, nchanged = %d\n", __func__, iter+1, F/points.size(), nchanged);
+ if (iter > 1 && Flast/F - 1 < 1e-6) break;
+ Flast = F;
+ }
+ int nzero = 0;
+ for (int ic = 0; ic < ncluster; ++ic) {
+ if (!counts[ic]) ++nzero;
+ }
+ if (nzero > 0) printf("%s: there are %d empty clusters\n", __func__, nzero);
+ return result;
+}
+
+// ========================================== iq2_kt ====================================================
+
+using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16>;
+
+const QuantizerIQ2KT& iq2kt_quantizer() {
+ static std::mutex mutex;
+ static std::unique_ptr<QuantizerIQ2KT> quantizer;
+ std::lock_guard<std::mutex> lock(mutex);
+ if (!quantizer) quantizer = std::make_unique<QuantizerIQ2KT>(256, 8);
+ return *quantizer;
+}
+
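+// Quantizes one row to iq2_kt: for each block of 32 values the best codebook indices and a
+// per-block scale are found (trying both signs of the initial guess), the block scales are
+// then snapped to the 4-bit iq4k_values grid against a single per-row float scale d, and one
+// refinement pass re-selects the indices with the final block scales and updates d.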
+void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights,
+ float * qtmp) {
+
+ constexpr float kSigmaScale = 2.0f;
+ using Q = QuantizerIQ2KT;
+
+ static_assert(Q::kNumVal%8 == 0);
+
+ float * dptr = (float *)vy;
+
+ block_iq2_kt * y = (block_iq2_kt *)(dptr + 1);
+
+ int best_idx[2*Q::kNg];
+
+ auto& quantizer = iq2kt_quantizer();
+
+ int nblock = n_per_row / Q::kSuperBlockSize;
+
+ Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights);
+
+ float amax_scale = 0, max_scale = 0;
+
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ memset(&y[ibl], 0, sizeof(block_iq2_kt));
+
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * xb = xbl + Q::kBlockSize*ib;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ float amax = 0;
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ float ax = std::abs(xb[j]);
+ amax = std::max(amax, ax);
+ }
+ quantizer.find_best_match( amax/96.f, xb, weight, best_idx);
+ auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx);
+ quantizer.find_best_match(-amax/96.f, xb, weight, best_idx + Q::kNg);
+ auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg);
+
+ auto idx = best_idx;
+ if (score_p > score_m) scales[ib] = dp;
+ else {
+ scales[ib] = dm; idx += Q::kNg;
+ }
+ auto qt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ for (int ig = 0; ig < Q::kNg; ++ig) {
+ auto q = quantizer.values() + idx[ig]*Q::kGroupSize;
+ for (int j = 0; j < Q::kGroupSize; ++j) qt[j] = q[j];
+ qt += Q::kGroupSize;
+ }
+
+ float abs_scale = std::abs(scales[ib]);
+ if (abs_scale > amax_scale) {
+ amax_scale = abs_scale;
+ max_scale = scales[ib];
+ }
+ }
+
+ }
+
+ if (!max_scale) {
+ *dptr = 0;
+ return;
+ }
+
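+ // Grid search for the per-row scale d: for each candidate, snap the block scales to the
+ // iq4k_values grid and keep the d that maximizes (sum w*x*q)^2 / (sum w*q*q).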
+ float d = max_scale/iq4k_values[0];
+ float best = 0;
+ for (int itry = -9; itry <= 9; ++itry) {
+ float id = (itry + iq4k_values[0])/max_scale;
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ const float * xb = x + ibl*Q::kSuperBlockSize;
+ const float * qb = qtmp + ibl*Q::kSuperBlockSize;
+ const float * wb = all_weights + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ int ls = best_index_iq4nl(iq4k_values, id*scales[ib]);
+ float dl = iq4k_values[ls];
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ float q = dl*qb[j];
+ sumqx += wb[j]*xb[j]*q;
+ sumq2 += wb[j]*q*q;
+ }
+ xb += Q::kBlockSize;
+ wb += Q::kBlockSize;
+ qb += Q::kBlockSize;
+ }
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ d = sumqx/sumq2; best = d*sumqx;
+ }
+ }
+
+ float id = d ? 1/d : 0.f;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ auto scales = all_scales + ibl*Q::kNblock;
+ for (int ib = 0; ib < Q::kNblock/2; ++ib) {
+ int ls1 = best_index_iq4nl(iq4k_values, id*scales[ib]);
+ int ls2 = best_index_iq4nl(iq4k_values, id*scales[ib + Q::kNblock/2]);
+ y[ibl].scales[ib] = ls1 | (ls2 << 4);
+ }
+ }
+
+ *dptr = d;
+ if (!d) return;
+
+ for (int iloop = 0; iloop < 1; ++iloop) {
+
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ auto qs = (uint16_t *)y[ibl].ql;
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * xb = xbl + Q::kBlockSize*ib;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ int ls = iq4k_values[(y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf];
+ float dl = d*ls;
+ quantizer.find_best_match(dl, xb, weight, best_idx);
+
+ for (int j = 0; j < Q::kNg; ++j) {
+ qs[j] = best_idx[j];
+ auto xl = xb + Q::kGroupSize*j;
+ auto wl = weight + Q::kGroupSize*j;
+ auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize;
+ for (int k = 0; k < Q::kGroupSize; ++k) {
+ float q = ql[k]*ls;
+ sumqx += wl[k]*xl[k]*q;
+ sumq2 += wl[k]*q*q;
+ }
+ }
+ qs += Q::kNg;
+ }
+ }
+ if (sumq2 > 0) {
+ d = sumqx/sumq2;
+ *dptr = d;
+ if (!d) return;
+ } else {
+ break;
+ }
+
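+ // Disabled: re-fit the block scales from the chosen indices and re-quantize them.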
+ if (false) {
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+ auto qs = (uint16_t *)y[ibl].ql;
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * xb = xbl + Q::kBlockSize*ib;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ for (int j = 0; j < Q::kNg; ++j) best_idx[j] = qs[ib*Q::kNg+j];
+ auto pair = quantizer.find_best_scale(xb, weight, best_idx);
+ scales[ib] = pair.first;
+ }
+ }
+ float id = d ? 1/d : 0.f;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ auto scales = all_scales + ibl*Q::kNblock;
+ for (int ib = 0; ib < Q::kNblock/2; ++ib) {
+ int ls1 = best_index_iq4nl(iq4k_values, id*scales[ib]);
+ int ls2 = best_index_iq4nl(iq4k_values, id*scales[ib + Q::kNblock/2]);
+ y[ibl].scales[ib] = ls1 | (ls2 << 4);
+ }
+ }
+ }
+
+ }
+
+}
+}
+
+void quantize_row_iq2_kt_ref(const float * GGML_RESTRICT x, block_iq2_kt * GGML_RESTRICT y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq2_kt(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq2_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq2_kt * y = (block_iq2_kt *)vy;
+ quantize_row_iq2_kt_ref(x, y, k);
+}
+
+size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ auto row_size = ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row);
+ std::vector<float> scales(n_per_row/QuantizerIQ2KT::kBlockSize);
+ std::vector<float> weights(n_per_row);
+ std::vector<float> xtmp(n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrows; ++row) {
+ quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data());
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrows * row_size;
+}
+
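+// An iq2_kt row is a float row scale followed by the blocks; each block stores eight 4-bit
+// scales (indices into iq4k_values, two per byte) and 32 16-bit codebook indices, one per
+// group of 8 values.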
+void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) {
+ assert(k % QuantizerIQ2KT::kSuperBlockSize == 0);
+ const int nb = k / QuantizerIQ2KT::kSuperBlockSize;
+ const float * dptr = (const float *)x;
+ const float d = *dptr * QuantizerIQ2KT::kScale;
+ x = (const block_iq2_kt *)(dptr + 1);
+ auto& deq = iq2kt_quantizer();
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ auto yl = y + ibl*QuantizerIQ2KT::kSuperBlockSize;
+ auto yh = yl + QuantizerIQ2KT::kSuperBlockSize/2;
+ const uint16_t * ql = (const uint16_t *)x[ibl].ql;
+ const uint16_t * qh = ql + QuantizerIQ2KT::kNg*QuantizerIQ2KT::kNblock/2;
+ for (int ib = 0; ib < QuantizerIQ2KT::kNblock/2; ++ib) {
+ float sl = d * iq4k_values[x[ibl].scales[ib] & 0xf];
+ float sh = d * iq4k_values[x[ibl].scales[ib] >> 4];
+ for (int ig = 0; ig < QuantizerIQ2KT::kNg; ++ig) {
+ deq.set_values(ql[ig], yl, sl);
+ deq.set_values(qh[ig], yh, sh);
+ yl += QuantizerIQ2KT::kGroupSize;
+ yh += QuantizerIQ2KT::kGroupSize;
+ }
+ ql += QuantizerIQ2KT::kNg;
+ qh += QuantizerIQ2KT::kNg;
+ }
+ }
+}
+
+void vec_dot_iq2_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+
+}
+
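+// ========================================== iq3_kt ====================================================
+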
+namespace {
+
+using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true>;
+const QuantizerIQ3KT& iq3kt_quantizer() {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ static std::unique_ptr<QuantizerIQ3KT> quantizer;
+ if (!quantizer) quantizer = std::make_unique<QuantizerIQ3KT>(256, 8);
+ return *quantizer;
+}
+
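+// Quantizes one row to iq3_kt. The magnitudes |x| are matched against the codebook in the
+// same way as for iq2_kt, the signs are stored separately as one bit per value in qh[] (bit
+// ib of qh[j] is the sign of the j-th value within block ib), and the block scales are
+// unsigned 4-bit integers.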
+void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales,
+ float * all_weights, float * qtmp) {
+
+ constexpr float kSigmaScale = 2.0f;
+ constexpr float kStep = 8.0f;
+
+ using Q = QuantizerIQ3KT;
+
+ static_assert(Q::kNumVal%8 == 0);
+
+ constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
+
+ float * dptr = (float *)vy;
+
+ block_iq3_kt * y = (block_iq3_kt *)(dptr + 1);
+
+ int best_idx[2*Q::kNg];
+
+ auto& quantizer = iq3kt_quantizer();
+
+ int nblock = n_per_row / Q::kSuperBlockSize;
+
+ float amax_row = 0;
+ for (int j = 0; j < n_per_row; ++j) amax_row = std::max(amax_row, std::abs(x[j]));
+ if (!amax_row) {
+ *dptr = 0.f;
+ std::memset(y, 0, nblock*sizeof(block_iq3_kt));
+ return;
+ }
+
+ Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights);
+
+ float amax_scale = 0, max_scale = 0;
+
+ float xaux[Q::kBlockSize];
+
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ memset(&y[ibl], 0, sizeof(block_iq3_kt));
+
+ auto scales = all_scales + ibl*Q::kNblock;
+ auto xbl = x + ibl*Q::kSuperBlockSize;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * xb = xbl + Q::kBlockSize*ib;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ float amax = 0;
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ float ax = std::abs(xb[j]);
+ xaux[j] = ax;
+ amax = std::max(amax, ax);
+ }
+ scales[ib] = 0;
+ if (!amax) continue;
+
+ //quantizer.find_best_match(amax/96.f, xaux, weight, best_idx+Q::kNg);
+ //scales[ib] = quantizer.find_best_scale(xaux, weight, best_idx+Q::kNg).first;
+
+ float scale_0 = std::max(84.f, 123.f*amax/amax_row);
+ //float scale_0 = std::max(64.f, 123.f*amax/amax_row);
+ float best = 0;
+ for (int itry = -3; itry <= 3; ++itry) {
+ quantizer.find_best_match(amax/(scale_0 + kStep*itry), xaux, weight, best_idx);
+ auto [d, score] = quantizer.find_best_scale(xaux, weight, best_idx);
+ if (score > best) {
+ best = score;
+ scales[ib] = d;
+ std::memcpy(best_idx+Q::kNg, best_idx, Q::kNg*sizeof(int));
+ }
+ }
+
+ auto xt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ for (int ig = 0; ig < Q::kNg; ++ig) {
+ auto q = quantizer.values() + Q::kGroupSize*best_idx[Q::kNg+ig];
+ for (int j = 0; j < Q::kGroupSize; ++j) *xt++ = q[j];
+ }
+
+ float abs_scale = std::abs(scales[ib]);
+ if (abs_scale > amax_scale) {
+ amax_scale = abs_scale;
+ max_scale = scales[ib];
+ }
+ }
+
+ }
+
+ GGML_ASSERT(max_scale >= 0);
+ float d = max_scale/15;
+ float best = 0;
+ for (int itry = -9; itry <= 9; ++itry) {
+ float id = (itry*0.2f + 15)/max_scale;
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ const float * xb = x + ibl*Q::kSuperBlockSize;
+ const float * qb = qtmp + ibl*Q::kSuperBlockSize;
+ const float * wb = all_weights + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ int ls = nearest_int(id*scales[ib]);
+ ls = std::max(0, std::min(15, ls));
+ float dl = ls;
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ float q = dl*qb[j];
+ sumqx += wb[j]*std::abs(xb[j])*q;
+ sumq2 += wb[j]*q*q;
+ }
+ xb += Q::kBlockSize;
+ wb += Q::kBlockSize;
+ qb += Q::kBlockSize;
+ }
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ d = sumqx/sumq2; best = d*sumqx;
+ }
+ }
+
+ float id = d ? 1/d : 0.f;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ auto scales = all_scales + ibl*Q::kNblock;
+ for (int ib = 0; ib < Q::kNblock/2; ++ib) {
+ int ls1 = nearest_int(id*scales[ib]);
+ int ls2 = nearest_int(id*scales[ib + Q::kNblock/2]);
+ ls1 = std::max(0, std::min(15, ls1));
+ ls2 = std::max(0, std::min(15, ls2));
+ y[ibl].scales[ib] = ls1 | (ls2 << 4);
+ }
+ }
+
+ *dptr = d;
+
+ for (int iloop = 0; iloop < 1; ++iloop) {
+
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ uint16_t * ql = (uint16_t *)y[ibl].ql;
+
+ std::memset(y[ibl].qh, 0, sizeof(y[ibl].qh)); // clear all sign bits before re-quantizing the block
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * xb = xbl + Q::kBlockSize*ib;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ xaux[j] = std::abs(xb[j]);
+ if (xb[j] < 0) y[ibl].qh[j] |= (1 << ib);
+ }
+ int ls = (y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf;
+ float dl = d*ls;
+ quantizer.find_best_match(dl, xaux, weight, best_idx);
+
+ for (int j = 0; j < Q::kNg; ++j) {
+ ql[ib*Q::kNg+j] = best_idx[j];
+ auto xl = xaux + Q::kGroupSize*j;
+ auto wl = weight + Q::kGroupSize*j;
+ auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize;
+ for (int k = 0; k < Q::kGroupSize; ++k) {
+ float q = ql[k]*ls;
+ sumqx += wl[k]*xl[k]*q;
+ sumq2 += wl[k]*q*q;
+ }
+ }
+ }
+ }
+ if (sumq2 > 0) {
+ d = sumqx/sumq2;
+ *dptr = d;
+ if (!d) break;
+ } else {
+ break;
+ }
+ }
+}
+}
+
+void quantize_row_iq3_kt_ref(const float * x, block_iq3_kt * y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq3_kt(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq3_kt(const float * x, void * vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq3_kt * y = (block_iq3_kt *)vy;
+ quantize_row_iq3_kt_ref(x, y, k);
+}
+
+size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ auto row_size = ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row);
+ std::vector<float> scales(n_per_row/QuantizerIQ3KT::kBlockSize);
+ std::vector<float> weights(n_per_row), xtmp(n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrows; ++row) {
+ quantize_row_iq3_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data());
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrows * row_size;
+}
+
+void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) {
+ using Q = QuantizerIQ3KT;
+ constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
+ assert(k % Q::kSuperBlockSize == 0);
+ const int nb = k / Q::kSuperBlockSize;
+ const float * dptr = (const float *)x;
+ const float d = *dptr * Q::kScale;
+ x = (const block_iq3_kt *)(dptr + 1);
+ auto& deq = iq3kt_quantizer();
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ auto yl = y + ibl*Q::kSuperBlockSize;
+ auto yh = yl + Q::kSuperBlockSize/2;
+ auto qll = (const uint16_t *)x[ibl].ql;
+ auto qlh = qll + kNumGroups/2;
+ int jj = 0;
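+ // The sign bits live in qh[]: bit ib of qh[j] flips the j-th value of low-half block ib,
+ // bit ib + kNblock/2 flips the corresponding value of the high-half block.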
+ for (int ib = 0; ib < Q::kNblock/2; ++ib) {
+ float sl = d * (x[ibl].scales[ib] & 0xf);
+ float sh = d * (x[ibl].scales[ib] >> 4);
+ uint8_t l_mask = 1 << ib;
+ uint8_t h_mask = l_mask << (Q::kNblock/2);
+ for (int ig = 0; ig < Q::kNg; ++ig) {
+ deq.set_values(qll[jj], yl, sl);
+ deq.set_values(qlh[jj], yh, sh);
+ for (int j = 0; j < Q::kGroupSize; ++j) {
+ if (x[ibl].qh[ig*Q::kGroupSize+j] & l_mask) yl[j] = -yl[j];
+ if (x[ibl].qh[ig*Q::kGroupSize+j] & h_mask) yh[j] = -yh[j];
+ }
+ yl += Q::kGroupSize;
+ yh += Q::kGroupSize;
+ ++jj;
+ }
+ }
+ }
+}
+
+void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ3_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+
+}
+
+// ======================================== iq4_kt
+
+namespace {
+
+using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15>;
+
+const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ static std::unique_ptr<QuantizerIQ4KT> quantizer1;
+ static std::unique_ptr<QuantizerIQ4KT> quantizer2;
+ if (with_offset) {
+ if (!quantizer2) quantizer2 = std::make_unique<QuantizerIQ4KT>(625, 6, 4096+32768);
+ return *quantizer2;
+ }
+ if (!quantizer1) quantizer1 = std::make_unique<QuantizerIQ4KT>(625, 6, 4096);
+ return *quantizer1;
+}
+
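+// Quantizes one row to iq4_kt. The row mean is stored in dptr[1] and subtracted from every
+// value; dptr[0] holds the row scale. Each block of 32 can be quantized with either of two
+// codebooks (quantizer1 built with offset 4096, quantizer2 with 4096+32768); bit 0 of the
+// block's shb word records the choice, bits 1-7 hold the block scale biased by 64, and bits
+// 8-31 hold the top 3 bits of the eight 15-bit group indices.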
+void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights) {
+
+ constexpr float kSigmaScale = 2.0f;
+ constexpr int kNtry = 2;
+ using Q = QuantizerIQ4KT;
+
+ static_assert(Q::kNumVal%8 == 0);
+
+ float * dptr = (float *)vy;
+
+ block_iq4_kt * y = (block_iq4_kt *)(dptr + 2);
+
+ auto& quantizer1 = iq4kt_quantizer();
+ auto& quantizer2 = iq4kt_quantizer(true);
+
+ int nblock = n_per_row / Q::kSuperBlockSize;
+
+ Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights);
+
+ float amax_row = 0, row_av = 0;
+ for (int j = 0; j < n_per_row; ++j) {
+ row_av += x[j];
+ amax_row = std::max(amax_row, std::abs(x[j]));
+ }
+ row_av /= n_per_row;
+ dptr[1] = row_av;
+ if (!amax_row) {
+ dptr[0] = 0.f;
+ std::memset(y, 0, nblock*sizeof(block_iq4_kt));
+ return;
+ }
+
+ int best_idx[2*Q::kNg];
+ float xaux[Q::kBlockSize];
+
+ float amax_scale = 0, max_scale = 0;
+
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ memset(&y[ibl], 0, sizeof(block_iq4_kt));
+
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ float amax = 0;
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av;
+ float ax = std::abs(xaux[j]);
+ amax = std::max(amax, ax);
+ }
+ if (!amax) {
+ scales[ib] = 0;
+ continue;
+ }
+ float best = 0;
+ float scale_0 = std::max(92.f, 127.f*amax/amax_row);
+ for (int itry = -kNtry; itry <= kNtry; ++itry) {
+ quantizer1.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx);
+ auto [dp, score_p] = quantizer1.find_best_scale(xaux, weight, best_idx);
+ if (score_p > best) {
+ best = score_p; scales[ib] = dp;
+ }
+ quantizer1.find_best_match(-amax/(8.f*itry + scale_0), xaux, weight, best_idx);
+ auto [dm, score_m] = quantizer1.find_best_scale(xaux, weight, best_idx);
+ if (score_m > best) {
+ best = score_m; scales[ib] = dm;
+ }
+ }
+
+ quantizer2.find_best_match(scales[ib], xaux, weight, best_idx);
+ auto [d, score] = quantizer2.find_best_scale(xaux, weight, best_idx);
+ if (score > best) {
+ scales[ib] = d;
+ y[ibl].qs[ib] = 1;
+ }
+ bool with_offset = false;
+ for (int itry = -kNtry; itry <= kNtry; ++itry) {
+ quantizer2.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx);
+ auto [dp, score_p] = quantizer2.find_best_scale(xaux, weight, best_idx);
+ if (score_p > best) {
+ best = score_p; scales[ib] = dp; with_offset = true;
+ }
+ quantizer2.find_best_match(-amax/(8.f*itry + scale_0), xaux, weight, best_idx);
+ auto [dm, score_m] = quantizer2.find_best_scale(xaux, weight, best_idx);
+ if (score_m > best) {
+ best = score_m; scales[ib] = dm; with_offset = true;
+ }
+ }
+ if (with_offset) y[ibl].qs[ib] = 1;
+
+ float abs_scale = std::abs(scales[ib]);
+ if (abs_scale > amax_scale) {
+ amax_scale = abs_scale;
+ max_scale = scales[ib];
+ }
+ }
+
+ }
+
+ float d = -max_scale/64;
+
+ dptr[0] = d;
+ if (!d) return;
+
+ constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
+
+ for (int iloop = 0; iloop < 1; ++iloop) {
+
+ const float id = 1/d;
+
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ // high 3 bits + scales
+ // each block of 32 needs 8 x 3 (high bits) + 1 x 8 (scale) = 32 bits = 1 x uint32_t
+ // we have 8 blocks
+ auto shb = y[ibl].qs; // high 3 bits + scales
+ auto ql = (uint8_t *)(shb + Q::kNblock);
+ auto qh = ql + kNumGroups;
+ std::memset(qh, 0, kNumGroups/2);
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ auto& quantizer = y[ibl].qs[ib] & 1 ? quantizer2 : quantizer1;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av;
+ int ls = nearest_int(id*scales[ib]);
+ ls = std::min(ls, 63);
+ *(uint8_t *)(shb + ib) = ((ls + 64) << 1) | (shb[ib] & 1);
+ float dl = d*ls;
+ quantizer.find_best_match(dl, xaux, weight, best_idx);
+
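+ // Each 15-bit index is split across the block data: low 8 bits in ql[], the next 4 bits
+ // as a nibble of qh[], and the top 3 bits in bits 8+3*j .. 10+3*j of shb[ib].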
+ for (int j = 0; j < Q::kNg; ++j) {
+ shb[ib] |= ((best_idx[j] >> 12) << (8 + 3*j));
+ ql[Q::kNg*ib + j] = best_idx[j] & 255;
+ qh[(Q::kNg*ib + j)%(kNumGroups/2)] |= ((best_idx[j] >> 8) & 0xf) << 4*((Q::kNg*ib + j)/(kNumGroups/2));
+ auto xl = xaux + Q::kGroupSize*j;
+ auto wl = weight + Q::kGroupSize*j;
+ auto ql = quantizer.values() + Q::kGroupSize*best_idx[j];
+ for (int k = 0; k < Q::kGroupSize; ++k) {
+ float q = ql[k]*ls;
+ sumqx += wl[k]*xl[k]*q;
+ sumq2 += wl[k]*q*q;
+ }
+ }
+ }
+ }
+ if (sumq2 > 0) {
+ d = sumqx/sumq2;
+ dptr[0] = d;
+ if (!d) break;
+ } else {
+ break;
+ }
+ }
+}
+}
+
+void quantize_row_iq4_kt_ref(const float * GGML_RESTRICT x, block_iq4_kt * GGML_RESTRICT y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq4_kt(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq4_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq4_kt * y = (block_iq4_kt *)vy;
+ quantize_row_iq4_kt_ref(x, y, k);
+}
+
+size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ auto row_size = ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row);
+ std::vector<float> scales(n_per_row/QuantizerIQ4KT::kBlockSize);
+ std::vector<float> weights(n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrows; ++row) {
+ quantize_row_iq4_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data());
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrows * row_size;
+}
+
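+// Decodes an iq4_kt row: the 15-bit group indices are reassembled from ql[], the qh[] nibbles
+// and the high bits of shb, each group is expanded via set_values() with the codebook offset
+// selected by bit 0 of shb, and the per-row mean is added back to every value.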
+void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) {
+ using Q = QuantizerIQ4KT;
+ assert(k % Q::kSuperBlockSize == 0);
+ constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
+ const int nb = k / Q::kSuperBlockSize;
+ const float * dptr = (const float *)x;
+ const float d = dptr[0] * Q::kScale;
+ const float row_av = dptr[1];
+ x = (const block_iq4_kt *)(dptr + 2);
+ auto& deq = iq4kt_quantizer();
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ auto shb = x[ibl].qs;
+ auto ql = (const uint8_t *)(shb + Q::kNblock);
+ auto qh = ql + kNumGroups;
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ int offset = shb[ib] & 1 ? 32768 + 4096 : 4096;
+ //auto& deq = shb[ib] & 1 ? deq2 : deq1;
+ int ls = int((shb[ib] & 0xff) >> 1) - 64;
+ float sl = d * ls;
+ for (int ig = 0; ig < Q::kNg; ++ig) {
+ int jj = ib*Q::kNg+ig;
+ uint16_t idx = ql[jj] | ((qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00) | (((shb[ib] >> (8 + 3*ig)) & 7) << 12);
+ deq.set_values(idx, y, sl, offset);
+ for (int j = 0; j < Q::kGroupSize; ++j) y[j] += row_av;
+ y += Q::kGroupSize;
+ }
+ }
+ }
+}
+
+void vec_dot_iq4_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+
+}
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 9c274d4b..70918a65 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -67,6 +67,24 @@ size_t quantize_iq2_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst
void dequantize_row_iq2_ks(const block_iq2_ks * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq2_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_iq2_kt_ref(const float * GGML_RESTRICT x, block_iq2_kt * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq2_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq2_kt(const block_iq2_kt * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq2_kt_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void quantize_row_iq3_kt_ref(const float * GGML_RESTRICT x, block_iq3_kt * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq3_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq3_kt(const block_iq3_kt * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq3_kt_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void quantize_row_iq4_kt_ref(const float * GGML_RESTRICT x, block_iq4_kt * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq4_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq4_kt(const block_iq4_kt * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq4_kt_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void quantize_row_iq5_ks_ref(const float * GGML_RESTRICT x, block_iq5_ks * GGML_RESTRICT y, int64_t k);
void quantize_row_iq5_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_iq5_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);