summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/quantize/quantize.cpp1
-rw-r--r--ggml/include/ggml.h2
-rw-r--r--ggml/src/ggml-common.h13
-rw-r--r--ggml/src/ggml-cuda.cu1
-rw-r--r--ggml/src/ggml-cuda/common.cuh7
-rw-r--r--ggml/src/ggml-cuda/convert.cu36
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu12
-rw-r--r--ggml/src/ggml-cuda/vecdotq.cuh32
-rw-r--r--ggml/src/ggml-quants.c1
-rw-r--r--ggml/src/ggml.c21
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp91
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp219
-rw-r--r--ggml/src/iqk/iqk_quantize.h6
-rw-r--r--include/llama.h5
-rw-r--r--src/llama.cpp18
15 files changed, 451 insertions, 14 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 2397e202..5f599c65 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -40,6 +40,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", },
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
+ { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",},
{ "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", },
{ "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index ff7f0064..2cb4af32 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -390,6 +390,7 @@ extern "C" {
GGML_TYPE_IQ2_BN = 35,
GGML_TYPE_Q8_K64 = 36,
GGML_TYPE_IQ4_K = 37,
+ GGML_TYPE_IQ2_K = 38,
GGML_TYPE_COUNT,
};
@@ -437,6 +438,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ1_BN = 28, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_BN = 29, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_K = 30, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_K = 31, // except 1d tensors
};
// available tensor operations:
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 755d52b9..9466dfcf 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -454,6 +454,14 @@ typedef struct {
} block_iq4_k;
static_assert(sizeof(block_iq4_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + 3*QK_K/64, "wrong iq4_k block size/padding");
+typedef struct {
+ ggml_half d;
+ uint16_t extra;
+ uint8_t scales[QK_K/32];
+ uint8_t qs[QK_K/4];
+} block_iq2_k;
+static_assert(sizeof(block_iq2_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/32 + QK_K/4, "wrong iq2_k block size/padding");
+
#endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL
@@ -1890,5 +1898,10 @@ GGML_TABLE_BEGIN(int8_t, iq4k_values, 32)
-123, -100, -79, -61, -45, -31, -18, -6, 5, 17, 29, 42, 57, 73, 93, 117
GGML_TABLE_END()
+GGML_TABLE_BEGIN(int8_t, iq2nl_values, 8)
+ -31, -13, 1, 17, -26, -8, 6, 22
+GGML_TABLE_END()
+
+
#endif // GGML_COMMON_IMPL
#endif // GGML_COMMON_IMPL
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index cfeda744..a4c93ad6 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2754,6 +2754,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
return true;
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 8549c4e5..12eebb00 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -670,6 +670,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index e7732cf5..6dd0fc50 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -543,6 +543,32 @@ static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_
}
}
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int i = blockIdx.x;
+ const block_iq2_k * x = (const block_iq2_k *) vx;
+
+ const int tid = threadIdx.x;
+ int ib128 = tid/16; // 0 or 1
+ int il = tid%16; // 0...15
+ dst_t * y = yy + i*QK_K + 128*ib128 + 2*il;
+ const float d = (float)x[i].d * 1.025f; //1.0325f;
+ const float dl1 = d * (2*((x[i].scales[4*ib128+0] >> 4*(il/8)) & 0xf) - 15);
+ const float dl2 = d * (2*((x[i].scales[4*ib128+1] >> 4*(il/8)) & 0xf) - 15);
+ const float dl3 = d * (2*((x[i].scales[4*ib128+2] >> 4*(il/8)) & 0xf) - 15);
+ const float dl4 = d * (2*((x[i].scales[4*ib128+3] >> 4*(il/8)) & 0xf) - 15);
+ const uint8_t * qs = x[i].qs + 32*ib128 + 2*il;
+ const int16_t extra = x[i].extra >> (8*ib128 + (il/8));
+ for (int j = 0; j < 2; ++j) {
+ y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)];
+ y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)];
+ y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)];
+ y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)];
+ }
+}
+
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
@@ -678,6 +704,12 @@ static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y);
}
+template<typename dst_t>
+static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_iq2_k<<<nb, 32, 0, stream>>>(vx, y);
+}
+
template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
@@ -744,6 +776,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_K:
return dequantize_row_iq4_k_cuda;
+ case GGML_TYPE_IQ2_K:
+ return dequantize_row_iq2_k_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F32:
@@ -797,6 +831,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_K:
return dequantize_row_iq4_k_cuda;
+ case GGML_TYPE_IQ2_K:
+ return dequantize_row_iq2_k_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F16:
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 5da32d99..b99dc245 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -25,6 +25,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
type == GGML_TYPE_IQ4_K ? vec_dot_iq4_k_q8_1 :
+ type == GGML_TYPE_IQ2_K ? vec_dot_iq2_k_q8_1 :
type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
nullptr;
}
@@ -48,6 +49,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ :
type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ :
type == GGML_TYPE_IQ4_K ? VDR_IQ4_K_Q8_1_MMVQ :
+ type == GGML_TYPE_IQ2_K ? VDR_IQ2_K_Q8_1_MMVQ :
1;
}
@@ -352,6 +354,13 @@ static void mul_mat_vec_iq4_k_q8_1_cuda(
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
+static void mul_mat_vec_iq2_k_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
static void mul_mat_vec_iq3_s_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -443,6 +452,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ4_K:
mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
+ case GGML_TYPE_IQ2_K:
+ mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index 9f2b2300..97a5619f 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -1274,3 +1274,35 @@ static __device__ __forceinline__ float vec_dot_iq4_k_q8_1(
return d * (sumi1 * ls1 + sumi2 * ls2);
}
+#define VDR_IQ2_K_Q8_1_MMVQ 4
+#define VDR_IQ2_K_Q8_1_MMQ 4
+
+// TODO
+static __device__ __forceinline__ float vec_dot_iq2_k_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+ return 0;
+//
+// const block_iq2_k * bq4 = (const block_iq2_k *) vbq + kbx;
+// const uint8_t * all_values = (const uint8_t *)iq4k_values;
+//
+// // iqs is 0...28
+// const int ib32 = iqs/4;
+// // Why iqs/4 ?
+// const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+// const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32;
+// const uint16_t extra = bq4->extra >> 2*ib32;
+// int v1, v2;
+// int sumi1 = 0, sumi2 = 0;
+// for (int j = 0; j < 4; ++j) {
+// const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16);
+// get_int_from_table_16_shift(aux32, extra, all_values, v1, v2);
+// sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1);
+// sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2);
+// }
+// const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds);
+// const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2);
+// const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
+// const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
+// return d * (sumi1 * ls1 + sumi2 * ls2);
+}
+
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index fef124c3..a5dbff12 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -14947,6 +14947,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break;
+ case GGML_TYPE_IQ2_K: break;
case GGML_TYPE_IQ4_K: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 6bfeca1e..0881756d 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -992,6 +992,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
+ [GGML_TYPE_IQ2_K] = {
+ .type_name = "iq2_k",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq2_k),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_k,
+ .from_float = quantize_row_iq2_k,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq2_k_ref,
+ .vec_dot = vec_dot_iq2_k_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
};
// For internal test use
@@ -3342,6 +3354,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break;
+ case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break;
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break;
@@ -9592,6 +9605,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -9973,6 +9987,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -10104,6 +10119,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -13024,6 +13040,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -13215,6 +13232,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -13480,6 +13498,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -14072,6 +14091,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q8_K:
@@ -20808,6 +20828,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 1fe0af74..ad09d341 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -742,6 +742,88 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
};
+struct IQXKScales {
+ IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm_set1_epi8(shift)), min(_mm256_set1_epi8(min_val)) {}
+ template <typename Q8>
+ inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const {
+ auto extra128 = _mm_set1_epi16(extra);
+ extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask);
+ extra128 = _mm_and_si128(extra128, e5);
+ extra128 = _mm_shuffle_epi8(extra128, eshuffle);
+ auto scales16 = _mm256_mullo_epi16(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)),
+ _mm256_add_epi16(_mm256_set1_epi16(-32), _mm256_cvtepi8_epi16(extra128)));
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
+ accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
+ }
+ scales16 = MM256_SET_M128I(scales8, scales8);
+ scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
+ scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
+ }
+ const __m128i eshift;
+ const __m256i min;
+ const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+ const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
+ const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
+ const __m128i e5 = _mm_set1_epi8(5);
+ const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000);
+ const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404);
+};
+
+struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
+ DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(IQXKScales(5, -32)), values(load_values()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ prepare(x[i].qs);
+ iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales);
+ //auto scales8 = make_scales(x[i].scales);
+ //auto extra128 = _mm_set1_epi16(x[i].extra);
+ //extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask);
+ //extra128 = _mm_and_si128(extra128, e5);
+ //extra128 = _mm_shuffle_epi8(extra128, eshuffle);
+ //auto scales16 = _mm256_mullo_epi16(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)),
+ // _mm256_add_epi16(_mm256_set1_epi16(-32), _mm256_cvtepi8_epi16(extra128)));
+ //for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ // const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
+ // accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
+ //}
+ //scales16 = MM256_SET_M128I(scales8, scales8);
+ //scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
+ //scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
+ }
+ inline void prepare(const uint8_t * q2) {
+ bits.prepare(q2);
+ bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]);
+ bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
+ bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
+ bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
+ }
+ static inline __m512i load_values() {
+ static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
+ auto val256 = MM256_SET_M128I(val128, val128);
+ return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
+ }
+ inline __m128i make_scales(const uint8_t * scales_l) const {
+ uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
+ return _mm_add_epi8(_mm_slli_epi16(scl, 1), m15);
+ }
+ Q2Bits bits;
+ IQXKScales iqxk;
+
+ const __m512i values;
+ const __m128i m15 = _mm_set1_epi8(-15);
+ //const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+ //const __m128i m15 = _mm_set1_epi8(-15);
+ //const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
+ //const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
+ //const __m128i e5 = _mm_set1_epi8(5);
+ //const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000);
+ //const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404);
+};
+
struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
template <typename Q8>
@@ -784,11 +866,6 @@ struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
auto sch = _mm_shuffle_epi8(aux, hshuff);
return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
}
- //static __m256i load_shuffle(int i) {
- // static const uint64_t k_shuffles[8] = {0x0202020200000000, 0x0a0a0a0a08080808, 0x0303030301010101, 0x0b0b0b0b09090909,
- // 0x0606060604040404, 0x0e0e0e0e0c0c0c0c, 0x0707070705050505, 0x0f0f0f0f0d0d0d0d};
- // return _mm256_loadu_si256((const __m256i *)k_shuffles + i);
- //}
Q4Bits bits;
const __m512i values;
@@ -2897,6 +2974,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4XS>(mm);
break;
+ case GGML_TYPE_IQ2_K:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ2K>(mm);
+ break;
case GGML_TYPE_IQ4_K:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4K>(mm);
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index e60e61a1..7722d630 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -471,8 +471,8 @@ void vec_dot_iq4_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx,
const int8_t * q8 = y[ibl].qs;
int32_t sum = 0;
for (int ib = 0; ib < QK_K/32; ++ib) {
- const int ls1 = (x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30) - 32;
- const int ls2 = (x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30) - 32;
+ const int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
+ const int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
h >>= 4;
const int8_t * values1 = iq4k_values + 16*(extra & 1);
const int8_t * values2 = iq4k_values + 8*(extra & 2);
@@ -698,3 +698,218 @@ size_t quantize_iq4_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
}
return nrows * nblock * sizeof(block_iq4_k);
}
+
+//
+// ============================================== iq2_K
+//
+
+namespace {
+
+inline int best_index_iq2nl(const int8_t * values, float x) {
+ int idx = x < values[1] ? 0 : x > values[2] ? 2 : 1;
+ return x - values[idx] < values[idx+1] - x ? idx : idx + 1;
+}
+
+void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) {
+
+ constexpr int kBlockSize = 16;
+
+ block_iq2_k * y = (block_iq2_k *)vy;
+
+ float scales[QK_K/kBlockSize];
+ float weight[kBlockSize];
+ float sumx[kBlockSize+1], sumw[kBlockSize+1];
+
+ std::array<std::pair<float,int>, kBlockSize> pairs;
+
+ const int8_t * shifted_values = iq2nl_values + 4;
+
+ for (int ibl = 0; ibl < n_per_row/QK_K; ++ibl) {
+
+ memset(&y[ibl], 0, sizeof(block_iq2_k));
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
+
+ const float * xbl = x + ibl*QK_K;
+ float sumx2 = 0;
+ for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j];
+ const float sigma2 = 1.5f*sumx2/QK_K;
+
+ uint16_t extra = 0;
+
+ float max_abs_scale = 0;
+
+ for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
+ const float * xb = xbl + kBlockSize*ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize;
+ for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j];
+ }
+ for (int j = 0; j < kBlockSize; ++j) pairs[j] = {xb[j], j};
+ std::sort(pairs.begin(), pairs.end());
+ sumx[0] = sumw[0] = 0;
+ for (int j = 0; j < kBlockSize; ++j) {
+ int jj = pairs[j].second;
+ sumw[j+1] = sumw[j] + weight[jj];
+ sumx[j+1] = sumx[j] + weight[jj]*xb[jj];
+ }
+ float best = 0, d = 0;
+ bool is_shifted = false;
+ float sumqx, sumq2;
+ for (int i1 = 0; i1 < kBlockSize; ++i1) {
+ for (int i2 = i1; i2 < kBlockSize; ++i2) {
+ for (int i3 = i2; i3 < kBlockSize; ++i3) {
+ sumqx = (sumx[i1] - sumx[ 0])*iq2nl_values[0] + (sumx[i2] - sumx[i1])*iq2nl_values[1]
+ + (sumx[i3] - sumx[i2])*iq2nl_values[2] + (sumx[kBlockSize] - sumx[i3])*iq2nl_values[3];
+ sumq2 = (sumw[i1] - sumw[ 0])*iq2nl_values[0]*iq2nl_values[0] + (sumw[i2] - sumw[i1])*iq2nl_values[1]*iq2nl_values[1]
+ + (sumw[i3] - sumw[i2])*iq2nl_values[2]*iq2nl_values[2] + (sumw[kBlockSize] - sumw[i3])*iq2nl_values[3]*iq2nl_values[3];
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ d = sumqx/sumq2; best = d*sumqx; is_shifted = false;
+ }
+ sumqx = (sumx[i1] - sumx[ 0])*shifted_values[0] + (sumx[i2] - sumx[i1])*shifted_values[1]
+ + (sumx[i3] - sumx[i2])*shifted_values[2] + (sumx[kBlockSize] - sumx[i3])*shifted_values[3];
+ sumq2 = (sumw[i1] - sumw[ 0])*shifted_values[0]*shifted_values[0] + (sumw[i2] - sumw[i1])*shifted_values[1]*shifted_values[1]
+ + (sumw[i3] - sumw[i2])*shifted_values[2]*shifted_values[2] + (sumw[kBlockSize] - sumw[i3])*shifted_values[3]*shifted_values[3];
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ d = sumqx/sumq2; best = d*sumqx; is_shifted = true;
+ }
+ sumqx = (sumx[i1] - sumx[ 0])*iq2nl_values[3] + (sumx[i2] - sumx[i1])*iq2nl_values[2]
+ + (sumx[i3] - sumx[i2])*iq2nl_values[1] + (sumx[kBlockSize] - sumx[i3])*iq2nl_values[0];
+ sumq2 = (sumw[i1] - sumw[ 0])*iq2nl_values[3]*iq2nl_values[3] + (sumw[i2] - sumw[i1])*iq2nl_values[2]*iq2nl_values[2]
+ + (sumw[i3] - sumw[i2])*iq2nl_values[1]*iq2nl_values[1] + (sumw[kBlockSize] - sumw[i3])*iq2nl_values[0]*iq2nl_values[0];
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ d = sumqx/sumq2; best = d*sumqx; is_shifted = false;
+ }
+ sumqx = (sumx[i1] - sumx[ 0])*shifted_values[3] + (sumx[i2] - sumx[i1])*shifted_values[2]
+ + (sumx[i3] - sumx[i2])*shifted_values[1] + (sumx[kBlockSize] - sumx[i3])*shifted_values[0];
+ sumq2 = (sumw[i1] - sumw[ 0])*shifted_values[3]*shifted_values[3] + (sumw[i2] - sumw[i1])*shifted_values[2]*shifted_values[2]
+ + (sumw[i3] - sumw[i2])*shifted_values[1]*shifted_values[1] + (sumw[kBlockSize] - sumw[i3])*shifted_values[0]*shifted_values[0];
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ d = sumqx/sumq2; best = d*sumqx; is_shifted = true;
+ }
+ }
+ }
+ }
+ scales[ib] = d;
+ if (is_shifted) extra |= (1 << ib);
+
+ float abs_scale = fabsf(scales[ib]);
+ max_abs_scale = MAX(max_abs_scale, abs_scale);
+ }
+
+ if (!max_abs_scale) continue;
+
+ float d = max_abs_scale/15;
+ y[ibl].d = GGML_FP32_TO_FP16(d);
+ y[ibl].extra = extra;
+ float id = 1/d;
+
+ float sumqx = 0, sumq2 = 0;
+ for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
+ int ls = nearest_int(0.5f*(id*scales[ib]+15));
+ ls = MAX(0, MIN(15, ls));
+ y[ibl].scales[ib/2] |= (ls << 4*(ib%2));
+ ls = 2*ls - 15;
+ float dl = d * ls;
+ if (dl) {
+ const int8_t * block_values = y[ibl].extra & (1 << ib) ? shifted_values : iq2nl_values;
+ const float * xb = xbl + kBlockSize*ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize;
+ for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j];
+ }
+ float idl = 1/dl;
+ int ib32 = ib/2;
+ int offset = 16*(ib%2);
+ uint8_t * qs = y[ibl].qs + 32*(ib32/4) + offset;
+ for (int j = 0; j < 16; ++j) {
+ const float al = idl*xb[j];
+ int ibest = best_index_iq2nl(block_values, al);
+ qs[j] |= (ibest << 2*(ib32%4));
+ float w = weight[j];
+ float q = block_values[ibest]*ls;
+ sumqx += w*q*xb[j];
+ sumq2 += w*q*q;
+ }
+ }
+ }
+ if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(sumqx/sumq2);
+
+ }
+}
+}
+
+void quantize_row_iq2_k_ref(const float * GGML_RESTRICT x, block_iq2_k * GGML_RESTRICT y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq2_k(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq2_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq2_k * y = (block_iq2_k *)vy;
+ quantize_row_iq2_k_ref(x, y, k);
+}
+
+size_t quantize_iq2_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrows; ++row) {
+ quantize_row_iq2_k_impl(src, (void *)qrow, n_per_row, imatrix);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq2_k);
+ }
+ return nrows * nblock * sizeof(block_iq2_k);
+}
+
+void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const uint8_t * qs = x[i].qs;
+
+ uint16_t extra = x[i].extra;
+
+ int shift = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ float dl1 = d * (2*(x[i].scales[ib32] & 0xf) - 15);
+ float dl2 = d * (2*(x[i].scales[ib32] >> 4) - 15);
+ const int8_t * values1 = extra & 1 ? iq2nl_values + 4 : iq2nl_values;
+ const int8_t * values2 = extra & 2 ? iq2nl_values + 4 : iq2nl_values;
+ extra >>= 2;
+ for (int j = 0; j < 16; ++j) {
+ y[j+ 0] = dl1 * values1[(qs[j+ 0] >> shift) & 3];
+ y[j+16] = dl2 * values2[(qs[j+16] >> shift) & 3];
+ }
+ y += 32;
+ shift += 2;
+ if (shift == 8) { qs += 32; shift = 0; }
+ }
+
+ }
+
+}
+
+void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+
+ const int nb = n / QK_K;
+
+ const block_iq2_k * x = (const block_iq2_k *)vx;
+ const block_q8_K * y = (const block_q8_K *)vy;
+}
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index dcc12dd2..f36eff38 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -19,6 +19,12 @@ size_t quantize_iq4_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
void dequantize_row_iq4_k(const block_iq4_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq4_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_iq2_k_ref(const float * GGML_RESTRICT x, block_iq2_k * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq2_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
#ifdef __cplusplus
}
#endif
diff --git a/include/llama.h b/include/llama.h
index a90b56e1..3549d3f3 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -168,9 +168,10 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_IQ1_BN = 36,
- LLAMA_FTYPE_MOSTLY_IQ2_BN = 37,
+ LLAMA_FTYPE_MOSTLY_IQ1_BN = 36, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ2_BN = 37, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_K = 38, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ2_K = 39, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
diff --git a/src/llama.cpp b/src/llama.cpp
index a87bfe59..3f9a211c 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3762,6 +3762,7 @@ struct llama_model_loader {
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break;
+ case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break;
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break;
case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break;
@@ -4458,6 +4459,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_K: return "IQ4_K - 4.5 bpw";
+ case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
@@ -15407,7 +15409,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
}
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ||
- ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
+ ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K) {
new_type = GGML_TYPE_Q5_K;
}
else if (new_type != GGML_TYPE_Q8_0) {
@@ -15464,6 +15466,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
}
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_K) {
+ if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_IQ4_K;
+ }
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
new_type = GGML_TYPE_Q4_K;
}
@@ -15573,9 +15578,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
if (arch != LLM_ARCH_FALCON) {
if (qs.model.hparams.n_expert == 8) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
- ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
- ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
- ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K) {
+ ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
+ ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
+ ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K ||
+ ftype == LLAMA_FTYPE_MOSTLY_IQ2_K) {
new_type = GGML_TYPE_Q5_K;
}
} else {
@@ -15629,7 +15635,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS ||
new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S ||
new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S ||
- new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K) {
+ new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K) {
int nx = tensor->ne[0];
int ny = tensor->ne[1];
if (nx % QK_K != 0) {
@@ -15656,6 +15662,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_IQ1_M:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
+ case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
case GGML_TYPE_IQ4_K:
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
@@ -15762,6 +15769,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
case LLAMA_FTYPE_MOSTLY_IQ4_K: default_type = GGML_TYPE_IQ4_K; break;
+ case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break;
case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break;