-rw-r--r--  examples/quantize/quantize.cpp     1
-rw-r--r--  ggml/include/ggml.h                2
-rw-r--r--  ggml/src/ggml-common.h            14
-rw-r--r--  ggml/src/ggml-cuda.cu              1
-rw-r--r--  ggml/src/ggml-cuda/common.cuh      7
-rw-r--r--  ggml/src/ggml-cuda/convert.cu     32
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu        12
-rw-r--r--  ggml/src/ggml-cuda/vecdotq.cuh    47
-rw-r--r--  ggml/src/ggml-metal.m             29
-rw-r--r--  ggml/src/ggml-metal.metal        164
-rw-r--r--  ggml/src/ggml-quants.c             1
-rw-r--r--  ggml/src/ggml.c                   24
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp     223
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp    286
-rw-r--r--  ggml/src/iqk/iqk_quantize.h       24
-rw-r--r--  include/llama.h                    1
-rw-r--r--  src/llama.cpp                     17
17 files changed, 864 insertions, 21 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 059d67a6..2397e202 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -40,6 +40,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", },
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
+ { "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", },
{ "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
{ "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 5ed8f73d..ff7f0064 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -389,6 +389,7 @@ extern "C" {
GGML_TYPE_IQ1_BN = 34,
GGML_TYPE_IQ2_BN = 35,
GGML_TYPE_Q8_K64 = 36,
+ GGML_TYPE_IQ4_K = 37,
GGML_TYPE_COUNT,
};
@@ -435,6 +436,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_BN = 28, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_BN = 29, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ4_K = 30, // except 1d tensors
};
// available tensor operations:
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index da3f1b3c..755d52b9 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -445,6 +445,15 @@ typedef struct {
} block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+typedef struct {
+ ggml_half d;
+ uint16_t extra;
+ uint8_t scales_h[QK_K/64];
+ uint8_t scales_l[QK_K/32];
+ uint8_t qs[QK_K/2];
+} block_iq4_k;
+static_assert(sizeof(block_iq4_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + 3*QK_K/64, "wrong iq4_k block size/padding");
+
#endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL
@@ -1876,5 +1885,10 @@ GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S)
GGML_TABLE_END()
#endif
+GGML_TABLE_BEGIN(int8_t, iq4k_values, 32)
+ -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
+ -123, -100, -79, -61, -45, -31, -18, -6, 5, 17, 29, 42, 57, 73, 93, 117
+GGML_TABLE_END()
+
#endif // GGML_COMMON_IMPL
#endif // GGML_COMMON_IMPL
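
For orientation: with QK_K = 256 the new block packs 256 weights into 144 bytes, which is the 4.5 bpw advertised in quantize.cpp. A minimal standalone sketch (a hypothetical mirror of the struct, with `ggml_half` approximated as `uint16_t`) that checks the arithmetic:

```cpp
#include <cstdint>
#include <cstdio>

constexpr int QK_K = 256;
struct block_iq4_k_mirror {          // mirrors block_iq4_k above
    uint16_t d;                      // super-block scale (fp16 bit pattern)
    uint16_t extra;                  // one bit per 16-quant sub-block: base vs shifted table
    uint8_t  scales_h[QK_K/64];      // top 2 bits of each 6-bit block scale
    uint8_t  scales_l[QK_K/32];      // low 4 bits of each 6-bit block scale
    uint8_t  qs[QK_K/2];             // 256 quants, 4 bits each, two per byte
};

int main() {
    static_assert(sizeof(block_iq4_k_mirror) == 144, "2 + 2 + 4 + 8 + 128 bytes");
    printf("bits per weight: %.2f\n", 8.0*sizeof(block_iq4_k_mirror)/QK_K); // 4.50
}
```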
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 59cf434c..cfeda744 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2753,6 +2753,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
return true;
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 7ea93264..8549c4e5 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -670,6 +670,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_S;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 66e68a52..e7732cf5 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -521,6 +521,28 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
}
}
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const int64_t i = blockIdx.x;
+ const block_iq4_k * x = (const block_iq4_k *)vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
+ const float d = (float)x[i].d;
+ const uint8_t sh = x[i].scales_h[ib/2] >> 4*(ib%2);
+ const float d1 = d * (((x[i].scales_l[ib] & 0xf) | ((sh << 4) & 0x30)) - 32);
+ const float d2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32);
+ const int8_t * values1 = iq4k_values + 16*((x[i].extra >> (2*ib+0)) & 1);
+ const int8_t * values2 = iq4k_values + 16*((x[i].extra >> (2*ib+1)) & 1);
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d1 * values1[q4[j] & 0xf];
+ y[j+16] = d2 * values2[q4[j] >> 4];
+ }
+}
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
@@ -650,6 +672,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}
+template<typename dst_t>
+static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y);
+}
+
template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
@@ -714,6 +742,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
+ case GGML_TYPE_IQ4_K:
+ return dequantize_row_iq4_k_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F32:
@@ -765,6 +795,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
+ case GGML_TYPE_IQ4_K:
+ return dequantize_row_iq4_k_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F16:
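
The CUDA kernel above and the CPU code later in this commit unpack the same 6-bit block scales: four low bits from `scales_l`, two high bits from `scales_h`, minus 32. A scalar illustration (hypothetical helper, not part of the commit):

```cpp
#include <cstdint>

// For block ib (0...7) of 32 quants: the low nibble of scales_l[ib] plus two bits
// of scales_h form the 6-bit scale of the first 16 quants; the high nibble plus
// the next two bits form the scale of the second 16. Both are centered at -32.
static void decode_iq4k_block_scales(const uint8_t * scales_l, const uint8_t * scales_h,
                                     int ib, int * ls1, int * ls2) {
    const uint8_t sh = scales_h[ib/2] >> 4*(ib%2);
    *ls1 = ((scales_l[ib] & 0xf) | ((sh << 4) & 0x30)) - 32;
    *ls2 = ((scales_l[ib] >>  4) | ((sh << 2) & 0x30)) - 32;
}
```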
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index b44000cd..5da32d99 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -24,6 +24,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
type == GGML_TYPE_IQ2_BN ? vec_dot_iq2_bn_q8_1 :
type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
+ type == GGML_TYPE_IQ4_K ? vec_dot_iq4_k_q8_1 :
type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
nullptr;
}
@@ -46,6 +47,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ :
type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ :
type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ :
+ type == GGML_TYPE_IQ4_K ? VDR_IQ4_K_Q8_1_MMVQ :
1;
}
@@ -343,6 +345,13 @@ static void mul_mat_vec_iq4_xs_q8_1_cuda(
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
+static void mul_mat_vec_iq4_k_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
static void mul_mat_vec_iq3_s_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -431,6 +440,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ4_XS:
mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
+ case GGML_TYPE_IQ4_K:
+ mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index 1248eacd..9f2b2300 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -1227,3 +1227,50 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds);
return d * sumi;
}
+
+static __device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,
+ int & val1, int & val2) {
+
+ uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
+ aux32 = q4 & 0x0f0f0f0f;
+ const uint8_t * values = all_values + 16*(shift & 1);
+ uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
+ uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
+ val1 = v1 | (v2 << 16);
+ aux32 = (q4 >> 4) & 0x0f0f0f0f;
+ values = all_values + 8*(shift & 2);
+ v1 = values[q8[0]] | (values[q8[1]] << 8);
+ v2 = values[q8[2]] | (values[q8[3]] << 8);
+ val2 = v1 | (v2 << 16);
+}
+
+#define VDR_IQ4_K_Q8_1_MMVQ 4
+#define VDR_IQ4_K_Q8_1_MMQ 4
+
+static __device__ __forceinline__ float vec_dot_iq4_k_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq4_k * bq4 = (const block_iq4_k *) vbq + kbx;
+ const uint8_t * all_values = (const uint8_t *)iq4k_values;
+
+    // iqs advances in steps of VDR_IQ4_K_Q8_1_MMVQ = 4, i.e. it takes the values 0, 4, ..., 28
+    // iqs/4 is then the index of the block of 32 quants, 0...7
+    const int ib32 = iqs/4;
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+ const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32;
+ const uint16_t extra = bq4->extra >> 2*ib32;
+ int v1, v2;
+ int sumi1 = 0, sumi2 = 0;
+ for (int j = 0; j < 4; ++j) {
+ const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16);
+ get_int_from_table_16_shift(aux32, extra, all_values, v1, v2);
+ sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1);
+ sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2);
+ }
+ const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds);
+ const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2);
+ const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
+ const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
+ return d * (sumi1 * ls1 + sumi2 * ls2);
+}
+
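`get_int_from_table_16_shift` packs four looked-up bytes into each 32-bit integer precisely so that `ggml_cuda_dp4a` can perform four int8 multiply-accumulates per instruction against the Q8_1 quants. A scalar model of the intrinsic's semantics, for reading the loop above (not CUDA code):

```cpp
#include <cstdint>

// dp4a: treat a and b as four int8 lanes, multiply lane-wise, accumulate into acc.
static int dp4a_ref(int a, int b, int acc) {
    for (int i = 0; i < 4; ++i) {
        acc += (int8_t)(a >> 8*i) * (int8_t)(b >> 8*i);
    }
    return acc;
}
```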
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 388c2008..37bf8cc4 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -90,6 +90,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
GGML_METAL_KERNEL_TYPE_RMS_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -120,6 +121,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -146,6 +148,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -169,6 +172,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -192,6 +196,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
@@ -559,6 +564,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K, get_rows_iq4_k, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -589,6 +595,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32, mul_mv_iq4_k_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -615,6 +622,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32, mul_mv_id_iq4_k_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -638,6 +646,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, mul_mm_iq4_k_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -661,6 +670,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32, mul_mm_id_iq4_k_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
@@ -1690,6 +1700,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
@@ -1872,6 +1883,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
} break;
+ case GGML_TYPE_IQ4_K:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32].pipeline;
+ } break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1916,7 +1933,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K) {
const int mem_size = 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -2007,6 +2024,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
@@ -2183,6 +2201,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
} break;
+ case GGML_TYPE_IQ4_K:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline;
+ } break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -2238,7 +2262,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K) {
const int mem_size = 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -2288,6 +2312,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
+ case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
}
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 67dcf53d..9bc33ceb 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -3056,6 +3056,11 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
+constexpr constant static float kvalues_iq4k_f[32] = {
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f,
+ -123.f, -100.f, -79.f, -61.f, -45.f, -31.f, -18.f, -6.f, 5.f, 17.f, 29.f, 42.f, 57.f, 73.f, 93.f, 117.f,
+};
+
kernel void kernel_cpy_f32_iq4_nl(
device const float * src0,
device void * dst,
@@ -5187,6 +5192,111 @@ void kernel_mul_mv_iq4_xs_f32_impl(
}
}
+void kernel_mul_mv_iq4_k_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values_i8,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = (r0 * 2 + sgitg) * 2;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ device const block_iq4_k * x = (device const block_iq4_k *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ const int ix = tiisg/16; // 0 or 1
+ const int it = tiisg%16; // 0...15
+ const int ib = it/2;
+ const int il = it%2;
+
+ shared_values[tiisg] = kvalues_iq4k_f[tiisg];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float4 yl[4];
+ float sumf[2]={0.f}, all_sum;
+
+ device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
+
+ uint32_t aux32[2];
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+ float4 qf1, qf2;
+
+ for (int ibl = ix; ibl < nb; ibl += 2) {
+
+ device const float4 * y4 = (device const float4 *)yb;
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+ //float2 sumy;
+ //sumy[0] = -4.f*(yl[0][0] + yl[0][1] + yl[0][2] + yl[0][3] + yl[2][0] + yl[2][1] + yl[2][2] + yl[2][3]);
+ //sumy[1] = -4.f*(yl[1][0] + yl[1][1] + yl[1][2] + yl[1][3] + yl[3][0] + yl[3][1] + yl[3][2] + yl[3][3]);
+
+ for (int row = 0; row < 2; ++row) {
+
+ device const block_iq4_k & xb = x[row*nb + ibl];
+ device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
+
+ uint16_t extra = xb.extra >> 2*ib;
+ threadgroup const float * values1 = shared_values + 16*(extra & 1);
+ threadgroup const float * values2 = shared_values + 8*(extra & 2);
+
+ float4 acc1 = {0.f}, acc2 = {0.f};
+
+ aux32[0] = q4[0] & 0x0f0f0f0f;
+ aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
+ qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]};
+ qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]};
+ acc1 += yl[0] * qf1;
+ acc2 += yl[1] * qf2;
+
+ aux32[0] = q4[1] & 0x0f0f0f0f;
+ aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
+ qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]};
+ qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]};
+ acc1 += yl[2] * qf1;
+ acc2 += yl[3] * qf2;
+
+ const uint8_t h = xb.scales_h[ib/2] >> 4*(ib%2);
+ const int ls1 = ((xb.scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
+ const int ls2 = ((xb.scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
+ sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3]));
+ //uint16_t extra = xb.extra >> 2*ib;
+ //sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3] + (extra & 1 ? sumy[0] : 0)) +
+ // ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3] + (extra & 2 ? sumy[1] : 0)));
+
+ }
+
+ yb += 2 * QK_K;
+ }
+
+ for (int row = 0; row < 2; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
device const void * src0,
@@ -5357,6 +5467,35 @@ kernel void kernel_mul_mv_iq4_xs_f32(
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
+[[host_name("kernel_mul_mv_iq4_k_f32")]]
+kernel void kernel_mul_mv_iq4_k_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq4_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
//============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
@@ -5827,6 +5966,27 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
}
}
+template <typename type4x4>
+void dequantize_iq4_k(device const block_iq4_k * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ const int l = il%2;
+ // l = 0 or 1. l = 0 processes the first 16 quants in a block of 32, l = 1 the second 16
+ device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+ const int ls = ((xb->scales_l[ib32] >> 4*l) & 0xf) | (((xb->scales_h[il/4] >> 2*(il%4)) & 3) << 4);
+ const float d = (float)xb->d * (ls - 32);
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+ constant float * values = kvalues_iq4k_f + 16*((xb->extra >> il) & 1);
+ for (int i = 0; i < 4; ++i) {
+ aux32 = (q4[i] >> 4*l) & 0x0f0f0f0f;
+ reg[i][0] = d * values[q8[0]];
+ reg[i][1] = d * values[q8[1]];
+ reg[i][2] = d * values[q8[2]];
+ reg[i][3] = d * values[q8[3]];
+ }
+}
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
device const void * src0,
@@ -6288,6 +6448,7 @@ template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_get_rows_iq4_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_k, QK_NL, dequantize_iq4_k>;
template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_bn, 4, dequantize_iq1_bn>;
template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_bn, 4, dequantize_iq2_bn>;
@@ -6318,6 +6479,7 @@ template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_k, QK_NL, dequantize_iq4_k>;
template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_bn, 4, dequantize_iq1_bn>;
template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_bn, 4, dequantize_iq2_bn>;
@@ -6350,6 +6512,7 @@ template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_bn, 4, dequantize_iq2_bn>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_k, QK_NL, dequantize_iq4_k>;
//
// matrix-vector multiplication
@@ -6561,3 +6724,4 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq4_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_k_f32_impl>>;
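
Across all three backends the `extra` field is consumed the same way: one bit per 16-quant sub-block selects between the two 16-entry rows of `iq4k_values` (the Metal kernels merely stage the table in threadgroup memory first). Expressions like `+ 16*(extra & 1)` and `+ 8*(extra & 2)` are the same row select applied to an even and an odd bit. A scalar sketch of the indexing (helper name is illustrative):

```cpp
#include <cstdint>

static const int8_t iq4k_values[32] = {  // copy of the table from ggml-common.h
    -127, -104, -83, -65, -49, -35, -22, -10,  1, 13, 25, 38, 53, 69, 89, 113,
    -123, -100, -79, -61, -45, -31, -18,  -6,  5, 17, 29, 42, 57, 73, 93, 117,
};

// Row select for block ib (0...7); sub = 0 for the first 16 quants, 1 for the second.
static inline const int8_t * iq4k_table(uint16_t extra, int ib, int sub) {
    return iq4k_values + 16*((extra >> (2*ib + sub)) & 1);
}
```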
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index da4c9b9a..fef124c3 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -14947,6 +14947,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break;
+ case GGML_TYPE_IQ4_K: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index c3cda4c4..6bfeca1e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -11,6 +11,7 @@
#include "ggml-quants.h"
#include "ggml.h"
#include "ggml-aarch64.h"
+#include "iqk/iqk_quantize.h"
#if GGML_USE_IQK_MULMAT
#include "iqk/iqk_mul_mat.h"
#endif
@@ -978,7 +979,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.ncols = 8,
.gemv = ggml_gemv_q4_0_8x8_q8_0,
.gemm = ggml_gemm_q4_0_8x8_q8_0,
- }
+ },
+ [GGML_TYPE_IQ4_K] = {
+ .type_name = "iq4_k",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq4_k),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq4_k,
+ .from_float = quantize_row_iq4_k,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq4_k_ref,
+ .vec_dot = vec_dot_iq4_k_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
};
// For internal test use
@@ -3328,6 +3341,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break;
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
+ case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break;
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break;
@@ -9577,6 +9591,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -9957,6 +9972,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -10087,6 +10103,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -13006,6 +13023,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -13196,6 +13214,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -13460,6 +13479,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -14051,6 +14071,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q8_K:
@@ -20786,6 +20807,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
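
The `type_traits` entry above is what plugs the new type into the generic ggml machinery: `vec_dot_type = GGML_TYPE_Q8_K` tells matrix multiplication to quantize the activations to Q8_K once, then call `vec_dot_iq4_k_q8_k` per weight row. A hypothetical miniature of that dispatch (names invented for illustration):

```cpp
#include <cstddef>
#include <cstdint>

// Same shape as ggml_vec_dot_t / vec_dot_iq4_k_q8_k above.
typedef void (*vec_dot_t)(int n, float * s, size_t bs, const void * x, size_t bx,
                          const void * y, size_t by, int nrc);

// One output column of a mat-vec: the activation row already converted to the
// companion type (Q8_K for iq4_k), weights stored row-major in quantized blocks.
static void matvec_rows(const char * wrows, size_t row_size, int nrows, int n,
                        const void * q8_row, vec_dot_t dot, float * out) {
    for (int r = 0; r < nrows; ++r) {
        dot(n, &out[r], 0, wrows + (size_t)r*row_size, 0, q8_row, 0, 1);
    }
}
```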
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index bf517504..1fe0af74 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -571,8 +571,16 @@ struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
Scales8K s8k;
};
+__m512i load_iq4nl_values_512() {
+ static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
+ auto val256 = MM256_SET_M128I(val128, val128);
+ return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
+}
+
+
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
- DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
+ DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
@@ -584,12 +592,6 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
}
- static __m512i load_values() {
- static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
- auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
- auto val256 = MM256_SET_M128I(val128, val128);
- return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
- }
inline void prepare(const uint8_t * q4) {
bits.prepare64(q4);
// We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
@@ -740,6 +742,70 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
};
+struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
+ DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ prepare(x[i].qs);
+ auto scales8 = make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h);
+ auto extra128 = _mm_set1_epi16(x[i].extra);
+ extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask);
+ extra128 = _mm_and_si128(extra128, e4);
+ extra128 = _mm_shuffle_epi8(extra128, eshuffle);
+ auto scales16 = _mm256_mullo_epi16(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff)),
+ _mm256_add_epi16(_mm256_set1_epi16(-128), _mm256_cvtepi8_epi16(extra128)));
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
+ accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
+ }
+ scales16 = MM256_SET_M128I(scales8, scales8);
+ scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
+ scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
+ }
+ inline void prepare(const uint8_t * q4) {
+ bits.prepare64(q4);
+        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
+        //                bits.values[1]: 16..31, 48...63, 80...95, 112..127
+ // etc.
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
+ bits.values[0] = _mm512_shuffle_epi8(values, tmp);
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
+ bits.values[2] = _mm512_shuffle_epi8(values, tmp);
+ }
+ __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
+ uint64_t aux64;
+ memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
+ const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
+ auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
+ auto sch = _mm_shuffle_epi8(aux, hshuff);
+ return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
+ }
+ //static __m256i load_shuffle(int i) {
+ // static const uint64_t k_shuffles[8] = {0x0202020200000000, 0x0a0a0a0a08080808, 0x0303030301010101, 0x0b0b0b0b09090909,
+ // 0x0606060604040404, 0x0e0e0e0e0c0c0c0c, 0x0707070705050505, 0x0f0f0f0f0d0d0d0d};
+ // return _mm256_loadu_si256((const __m256i *)k_shuffles + i);
+ //}
+
+ Q4Bits bits;
+ const __m512i values;
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
+ const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000);
+ const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404);
+ const __m128i maskl = _mm_set1_epi8(0xf);
+ const __m128i maskh = _mm_set1_epi8(0x30);
+ const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+ const __m128i m32 = _mm_set1_epi8(-32);
+ const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
+ const __m128i e4 = _mm_set1_epi8(4);
+ const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
+
+};
+
template <typename Q8>
inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
@@ -933,8 +999,14 @@ struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
Scales8K s8k;
};
+__m256i load_iq4nl_values() {
+ static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
+ return MM256_SET_M128I(val128, val128);
+}
+
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
- DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
+ DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values()) {}
template <typename Q8>
inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
d = GGML_FP16_TO_FP32(x[i].d);
@@ -950,18 +1022,59 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
}
- static __m256i load_values() {
- static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
- auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
- return MM256_SET_M128I(val128, val128);
- }
-
Q4Bits bits;
Scales8K s8k;
ScaleIQ4XS siq4;
const __m256i values;
};
+struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
+ DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ auto scales8 = make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h);
+ auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff));
+ auto extra128 = _mm_set1_epi16(x[i].extra);
+ extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask);
+ extra128 = _mm_and_si128(extra128, e4);
+ extra128 = _mm_shuffle_epi8(extra128, eshuffle);
+ auto scales_s = _mm256_mullo_epi16(scales16, _mm256_add_epi16(_mm256_set1_epi16(-128), _mm256_cvtepi8_epi16(extra128)));
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i));
+ accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
+ }
+ prepare_scales_16(scales16, scales);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare16(x[i].qs, j);
+ bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
+ bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
+ bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
+ bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
+ }
+ __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
+ uint64_t aux64;
+ memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
+ const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
+ auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
+ auto sch = _mm_shuffle_epi8(aux, hshuff);
+ return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
+ }
+
+ Q4Bits bits;
+ const __m256i values;
+ const __m128i maskl = _mm_set1_epi8(0xf);
+ const __m128i maskh = _mm_set1_epi8(0x30);
+ const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+ const __m128i m32 = _mm_set1_epi8(-32);
+ const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
+ const __m128i e4 = _mm_set1_epi8(4);
+ const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
+
+};
+
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
@@ -2696,7 +2809,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
#else
if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||
std::is_same_v<Dequantizer, DequantizerQ3K> ||
- std::is_same_v<Dequantizer, DequantizerQ6K>) {
+ std::is_same_v<Dequantizer, DequantizerQ6K> ||
+ std::is_same_v<Dequantizer, DequantizerIQ4K>) {
m.funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;
@@ -2783,6 +2897,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4XS>(mm);
break;
+ case GGML_TYPE_IQ4_K:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ4K>(mm);
+ break;
case GGML_TYPE_IQ3_S:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ3S>(mm);
@@ -3345,6 +3463,78 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
// ============================= i-quants
+inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
+ int16x8x2_t scales16{vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8))};
+ return make_wider(scales16);
+}
+
+struct Scale16Extra {
+ template <typename Q8>
+ static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val,
+ const uint8_t * scales_l, const uint8_t * scales_h, const Q8& q8, float32x4_t * acc) {
+ uint8x8_t aux = vld1_u8(scales_l);
+ uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
+ const uint32_t * aux32 = (const uint32_t *)scales_h;
+ uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2};
+ uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30));
+ int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, vreinterpretq_u8_u32(hshuff)));
+ scales8 = vaddq_s8(vqtbl1q_s8(scales8, vreinterpretq_u8_u32(hshuff)), vdupq_n_s8(-32));
+ return new_block(i, d, extra, val, scales8, q8, acc);
+ }
+ inline static uint8x16_t get_extra(uint16_t extra) {
+ uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra));
+ e8 = vceqq_u8(vandq_u8(e8, emask), emask);
+ return vqtbl1q_u8(e8, eshuff);
+ }
+ template <typename Q8>
+ static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val,
+ const int8x16_t& scales8, const Q8& q8, float32x4_t * acc) {
+ uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra));
+ e8 = vceqq_u8(vandq_u8(e8, emask), emask);
+ e8 = vqtbl1q_u8(vandq_u8(e8, vdupq_n_u8(val)), eshuff);
+ int16x8x2_t extra16 = {vmull_s8(vget_low_s8 (e8), vget_low_s8 (scales8)),
+ vmull_s8(vget_high_s8(e8), vget_high_s8(scales8))};
+ accum_mins_16(extra16, q8, acc, i, d);
+ return make_wider_8(scales8);
+ }
+
+ constexpr static uint32x4_t hshuff = {0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06};
+ constexpr static uint32x4_t emask = {0x02020101, 0x08080404, 0x20201010, 0x80804040};
+ constexpr static uint32x4_t eshuff = {0x06040200, 0x0e0c0a08, 0x07050301, 0x0f0d0b09};
+};
+
+// Note: on ARM_NEON we cannot use the values shifted into the uint8_t range because
+// ARM_NEON only has vdotq_s32 and vdotq_u32, where both operands must be signed
+// or both unsigned. As the Q8_K quants are signed, the iq4_k quants must also be
+// signed. Using unsigned values works for the k-quants only because those all
+// fit within the valid int8_t range anyway.
+struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
+ DequantizerIQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8(iq4k_values)) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ inline void new_row(int ix) { x = (const block_iq4_k *)((const char *)vx + bx*ix); }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return Scale16Extra::new_block(i, d, x[i].extra, 4, x[i].scales_l, x[i].scales_h, q8, acc);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare16(x[i].qs+64*j);
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
+ bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
+ }
+ }
+
+ Q4bits bits;
+    const int8x16_t values;
+
+ float d;
+};
+
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
static int8x16_t load_values() {
@@ -4671,6 +4861,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ4_XS:
MulMat::set_functions<DequantizerIQ4XS>(m);
break;
+ case GGML_TYPE_IQ4_K:
+ MulMat::set_functions<DequantizerIQ4K>(m);
+ break;
case GGML_TYPE_IQ2_XXS:
MulMat::set_functions<DequantizerIQ2XXS>(m);
break;
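
The `e4 = _mm_set1_epi8(4)` and `_mm256_set1_epi16(-128)` constants above rely on a property of the table: the shifted row of `iq4k_values` is exactly the base row plus 4. The AVX paths look quants up in the unsigned iq4_nl table (the signed values + 128), so both the -128 de-bias and the optional +4 shift can be folded into the Q8_K block sums (bsums) instead of touching the quants; the NEON path uses signed values, so only the +4 goes through the bsums (the `val = 4` argument of `Scale16Extra::new_block`). A scalar model of the correction (illustrative only):

```cpp
#include <cstdint>

// sum_j q8[j]*(v_u[j] - 128 + 4*shift) = sum_j q8[j]*v_u[j] + (4*shift - 128)*sum_j q8[j]
static int32_t block_dot_with_correction(const uint8_t * v_u,  // unsigned table values
                                         const int8_t * q8, int n, int shift /* 0 or 1 */) {
    int32_t dot = 0, bsum = 0;
    for (int j = 0; j < n; ++j) { dot += q8[j]*v_u[j]; bsum += q8[j]; }
    return dot + (4*shift - 128)*bsum;
}
```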
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 8f541565..e60e61a1 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -11,6 +11,7 @@
#include "ggml-impl.h"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
+#include "iqk_quantize.h"
#include <vector>
#include <utility>
@@ -412,3 +413,288 @@ void quantize_row_q8_K64(const float * x, void * y, int64_t k) {
quantize_row_q8_K64_ref(x, (block_q8_K64 *)y, k);
}
+//
+// ============================================== iq4_K
+//
+void dequantize_row_iq4_k(const block_iq4_k * x, float * y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ const uint8_t * qs = x[i].qs;
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ uint16_t extra = x[i].extra;
+
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ const uint8_t sh = x[i].scales_h[ib/2] >> 4*(ib%2);
+ const float dl1 = d * (((x[i].scales_l[ib] & 0xf) | ((sh << 4) & 0x30)) - 32);
+ const float dl2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32);
+ const int8_t * values1 = extra & 1 ? iq4k_values + 16 : iq4k_values;
+ const int8_t * values2 = extra & 2 ? iq4k_values + 16 : iq4k_values;
+ extra >>= 2;
+ for (int j = 0; j < 16; ++j) {
+ y[j+ 0] = dl1 * values1[qs[j] & 0xf];
+ y[j+16] = dl2 * values2[qs[j] >> 4];
+ }
+ y += 32;
+ qs += 16;
+ }
+ }
+}
+
+void vec_dot_iq4_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+
+ const int nb = n / QK_K;
+
+ const block_iq4_k * x = (const block_iq4_k *)vx;
+ const block_q8_K * y = (const block_q8_K *)vy;
+
+ float sumf = 0;
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
+ uint16_t extra = x[ibl].extra;
+ uint32_t h = *((const uint32_t *)x[ibl].scales_h);
+ const uint8_t * qs = x[ibl].qs;
+ const int8_t * q8 = y[ibl].qs;
+ int32_t sum = 0;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+            const int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
+            const int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
+ h >>= 4;
+ const int8_t * values1 = iq4k_values + 16*(extra & 1);
+ const int8_t * values2 = iq4k_values + 8*(extra & 2);
+ extra >>= 2;
+ int sumi1 = 0, sumi2 = 0;
+ for (int j = 0; j < 16; ++j) {
+ sumi1 += q8[j+ 0] * values1[qs[j] & 0xf];
+ sumi2 += q8[j+16] * values2[qs[j] >> 4];
+ }
+ sum += ls1*sumi1 + ls2*sumi2;
+ qs += 16;
+ q8 += 32;
+ }
+ sumf += d4d8 * sum;
+ }
+ *s = sumf;
+
+}
+
+namespace {
+const int8_t iq4nl_index[241] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5,
+ 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
+ 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
+ 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14,
+ 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14
+};
+inline int best_index_iq4nl(const int8_t * values, float x) {
+ if (x <= values[ 0]) return 0;
+ if (x >= values[15]) return 15;
+ int index = iq4nl_index[(int)x - values[0]];
+ return x - values[index] < values[index+1] - x ? index : index + 1;
+}
+
+static void quantize_row_iq4_k_impl_bs16(const int super_block_size, const int block_size, const float * x,
+ block_iq4_k * y,
+ float * scales, float * weight, uint8_t * L,
+ const int8_t * values,
+ const float * quant_weights,
+ const int ntry) {
+
+ GGML_ASSERT(super_block_size == 256 && block_size == 16);
+
+ float sigma2 = 0;
+ for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
+ sigma2 *= 2.f/super_block_size;
+
+ memset(y, 0, sizeof(block_iq4_k));
+ y->d = GGML_FP32_TO_FP16(0.f);
+
+ uint16_t * scales_h = (uint16_t *)y->scales_h;
+
+ const int8_t * shifted_values = values + 16;
+
+ float max_scale = 0, amax_scale = 0;
+ uint16_t extra = 0;
+ for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+ const float * xb = x + ib*block_size;
+ if (quant_weights) {
+ const float * qw = quant_weights + ib*block_size;
+ for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
+ }
+ float amax = 0, max = 0;
+ for (int j = 0; j < block_size; ++j) {
+ float ax = fabsf(xb[j]);
+ if (ax > amax) {
+ amax = ax; max = xb[j];
+ }
+ }
+ if (!amax) {
+ scales[ib] = 0;
+ continue;
+ }
+ float d = ntry > 0 ? -max/values[0] : max/values[0];
+ float id = 1/d;
+ float sumqx_p = 0, sumq2_p = 0;
+ float sumqx_m = 0, sumq2_m = 0;
+ for (int j = 0; j < block_size; ++j) {
+ float w = weight[j];
+ float al = id*xb[j];
+ int l = best_index_iq4nl(values, al);
+ float q = values[l];
+ sumqx_p += w*q*xb[j];
+ sumq2_p += w*q*q;
+ l = best_index_iq4nl(values, -al);
+ q = values[l];
+ sumqx_m += w*q*xb[j];
+ sumq2_m += w*q*q;
+ }
+ d = sumqx_p/sumq2_p;
+ bool is_shifted = false;
+ float best = d*sumqx_p;
+ if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
+ d = sumqx_m/sumq2_m; best = d*sumqx_m;
+ }
+ for (int itry = -ntry; itry <= ntry; ++itry) {
+ id = (itry + values[0])/max;
+ sumqx_p = sumq2_p = 0;
+ sumqx_m = sumq2_m = 0;
+ for (int j = 0; j < block_size; ++j) {
+ float w = weight[j];
+ float al = id*xb[j];
+ int l = best_index_iq4nl(values, al);
+ float q = values[l];
+ sumqx_p += w*q*xb[j];
+ sumq2_p += w*q*q;
+ l = best_index_iq4nl(values, -al);
+ q = values[l];
+ sumqx_m += w*q*xb[j];
+ sumq2_m += w*q*q;
+ }
+ if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
+ d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false;
+ }
+ if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
+ d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false;
+ }
+ id = (itry + shifted_values[0])/max;
+ sumqx_p = sumq2_p = 0;
+ sumqx_m = sumq2_m = 0;
+ for (int j = 0; j < block_size; ++j) {
+ float w = weight[j];
+ float al = id*xb[j];
+ int l = best_index_iq4nl(shifted_values, al);
+ float q = shifted_values[l];
+ sumqx_p += w*q*xb[j];
+ sumq2_p += w*q*q;
+ l = best_index_iq4nl(shifted_values, -al);
+ q = shifted_values[l];
+ sumqx_m += w*q*xb[j];
+ sumq2_m += w*q*q;
+ }
+ if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
+ d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true;
+ }
+ if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
+ d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true;
+ }
+ }
+ if (is_shifted) extra |= (1 << ib);
+ scales[ib] = d;
+ float abs_d = fabsf(d);
+ if (abs_d > amax_scale) {
+ amax_scale = abs_d; max_scale = d;
+ }
+ }
+ float d = -max_scale/32;
+ y->d = GGML_FP32_TO_FP16(d);
+ y->extra = extra;
+ float id = d ? 1/d : 0.f;
+ float sumqx = 0, sumq2 = 0;
+ for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+ const int8_t * block_values = extra & (1 << ib) ? shifted_values : values;
+ int l = nearest_int(id*scales[ib]);
+ l = MAX(-32, MIN(31, l));
+ float dl = d * l;
+ float idl = dl ? 1/dl : 0.f;
+ uint8_t * Lb = L + ib*block_size;
+ const float * xb = x + ib*block_size;
+ if (quant_weights) {
+ const float * qw = quant_weights + ib*block_size;
+ for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
+ }
+ for (int j = 0; j < block_size; ++j) {
+ Lb[j] = best_index_iq4nl(block_values, idl*xb[j]);
+ float w = weight[j];
+ float q = block_values[Lb[j]]*l;
+ sumqx += w*q*xb[j];
+ sumq2 += w*q*q;
+ }
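+ // store the block scale as l+32 in [0, 63]: low nibble in scales_l (two
+ // blocks per byte), top two bits in scales_h (eight blocks per uint16)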
+ l += 32;
+ uint8_t l_l = l & 0xf;
+ uint8_t l_h = l >> 4;
+ if (ib%2 == 0) y->scales_l[ib/2] = l_l;
+ else y->scales_l[ib/2] |= (l_l << 4);
+ scales_h[ib/8] |= (l_h << 2*(ib%8));
+ }
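+ // refit the super-block scale to the final quantized values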
+ if (sumq2 > 0) y->d = GGML_FP32_TO_FP16(sumqx/sumq2);
+
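+ // pack the 4-bit indices two per byte: within each group of 32 the first 16
+ // go into the low nibbles and the next 16 into the high nibbles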
+ for (int i = 0; i < super_block_size/32; ++i) {
+ for (int j = 0; j < 16; ++j) {
+ y->qs[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
+ }
+ }
+}
+
+} // end of anonymous namespace holding the file-local helpers
+
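+// For reference, a sketch of how a 16-value block decodes (cf.
+// dequantize_row_iq4_k declared in iqk_quantize.h):
+//   ls = (low 4 bits from scales_l | top 2 bits from scales_h << 4) - 32
+//   x[j] ~= GGML_FP16_TO_FP32(d) * ls * values[4-bit index from qs]
+// with values = iq4k_values or iq4k_values + 16, per the block's extra bit.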
+void quantize_row_iq4_k_ref(const float * x, block_iq4_k * y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq4_k(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq4_k(const float * x, void * vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq4_k * y = (block_iq4_k *)vy;
+ quantize_row_iq4_k_ref(x, y, k);
+}
+
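+// row-wise quantization entry point; imatrix, when present, supplies one
+// importance weight per model column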
+size_t quantize_iq4_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ uint8_t L[QK_K];
+ float weight[16];
+ float scales[QK_K/16];
+ for (int64_t row = 0; row < nrows; ++row) {
+ block_iq4_k * iq4 = (block_iq4_k *)qrow;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ const float * qw = imatrix ? imatrix + QK_K*ibl : NULL;
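+ // ntry = 7: the refinement loop in the impl tries 15 candidate inverse
+ // scales per grid (itry in [-7, 7])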
+ quantize_row_iq4_k_impl_bs16(QK_K, 16, src + QK_K*ibl, iq4 + ibl,
+ scales, weight, L, iq4k_values, qw, 7);
+ }
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq4_k);
+ }
+ return nrows * nblock * sizeof(block_iq4_k);
+}
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
new file mode 100644
index 00000000..dcc12dd2
--- /dev/null
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <stdint.h>
+#include <stddef.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+
+#ifdef __cplusplus
+#define GGML_RESTRICT
+extern "C" {
+#else
+#define GGML_RESTRICT restrict
+#endif
+
+void quantize_row_iq4_k_ref(const float * GGML_RESTRICT x, block_iq4_k * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq4_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq4_k(const block_iq4_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq4_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/include/llama.h b/include/llama.h
index 246fdf32..a90b56e1 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -170,6 +170,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_BN = 36,
LLAMA_FTYPE_MOSTLY_IQ2_BN = 37,
+ LLAMA_FTYPE_MOSTLY_IQ4_K = 38, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
diff --git a/src/llama.cpp b/src/llama.cpp
index eecfccbd..a87bfe59 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3761,6 +3761,7 @@ struct llama_model_loader {
case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break;
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
+ case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break;
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break;
case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break;
@@ -4456,6 +4457,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
+ case LLAMA_FTYPE_MOSTLY_IQ4_K: return "IQ4_K - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
@@ -15478,7 +15480,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
- else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
+ else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K) && qs.model.hparams.n_gqa() >= 4) {
new_type = GGML_TYPE_Q5_K;
}
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
@@ -15495,6 +15497,14 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
// TODO: explore better strategies
new_type = GGML_TYPE_Q8_0;
}
+ else if (qs.model.hparams.n_gqa() >= 4) {
+ if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
+ else if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_IQ3_S ) new_type = GGML_TYPE_Q4_K;
+ else if (new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_IQ4_XS || new_type == GGML_TYPE_IQ4_K) new_type = GGML_TYPE_Q5_K;
+ else if (new_type == GGML_TYPE_IQ4_NL) new_type = GGML_TYPE_Q5_K;
+ else if (new_type == GGML_TYPE_Q5_K) new_type = GGML_TYPE_Q6_K;
+ }
++qs.i_attention_wv;
} else if (name.find("attn_k.weight") != std::string::npos) {
if (qs.model.hparams.n_expert == 8) {
@@ -15566,7 +15575,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
- ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
+ ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K) {
new_type = GGML_TYPE_Q5_K;
}
} else {
@@ -15620,7 +15629,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS ||
new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S ||
new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S ||
- new_type == GGML_TYPE_IQ1_M) {
+ new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K) {
int nx = tensor->ne[0];
int ny = tensor->ne[1];
if (nx % QK_K != 0) {
@@ -15648,6 +15657,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
+ case GGML_TYPE_IQ4_K:
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
@@ -15751,6 +15761,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break;
case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
+ case LLAMA_FTYPE_MOSTLY_IQ4_K: default_type = GGML_TYPE_IQ4_K; break;
case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break;