-rw-r--r--  examples/quantize/quantize.cpp  |   1
-rw-r--r--  ggml/include/ggml.h             |   2
-rw-r--r--  ggml/src/ggml-common.h          |  13
-rw-r--r--  ggml/src/ggml-cuda.cu           |   1
-rw-r--r--  ggml/src/ggml-cuda/common.cuh   |   7
-rw-r--r--  ggml/src/ggml-cuda/convert.cu   |  36
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu      |  12
-rw-r--r--  ggml/src/ggml-cuda/vecdotq.cuh  |  32
-rw-r--r--  ggml/src/ggml-quants.c          |   1
-rw-r--r--  ggml/src/ggml.c                 |  21
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp    |  91
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp   | 219
-rw-r--r--  ggml/src/iqk/iqk_quantize.h     |   6
-rw-r--r--  include/llama.h                 |   5
-rw-r--r--  src/llama.cpp                   |  18
15 files changed, 451 insertions, 14 deletions
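
For orientation before the diff itself: IQ2_K packs each super-block of QK_K = 256 weights into 2 bytes of fp16 super-block scale, 2 bytes of per-sub-block table-select bits ("extra", one bit per 16 weights), 8 bytes of packed 4-bit sub-block scales, and 64 bytes of 2-bit quants, i.e. 76 bytes per 256 weights = 2.375 bpw, which is where the "2.375 bpw non-linear quantization" description added to quantize.cpp comes from. The eight-entry iq2nl_values table is really two four-value grids ({-31, -13, 1, 17} and the shifted {-26, -8, 6, 22}); the extra bit picks, per 16-weight sub-block, whichever grid fits better. The following is a minimal scalar sketch mirroring the dequantize_row_iq2_k() added in ggml/src/iqk/iqk_quantize.cpp; it assumes the fp16 super-block scale has already been converted to float, and block_iq2_k_sketch is a simplified stand-in for the real block_iq2_k declared in ggml-common.h (which uses ggml_half and GGML_FP16_TO_FP32).

#include <cstdint>

constexpr int QK_K = 256;

// Non-linear 2-bit grid: entries 0..3 are the base table, entries 4..7 the
// "shifted" table selected per 16-weight sub-block via the extra bits.
static const int8_t iq2nl_values[8] = { -31, -13, 1, 17, -26, -8, 6, 22 };

struct block_iq2_k_sketch {
    float    d;                  // stands in for the fp16 super-block scale
    uint16_t extra;              // 16 table-select bits, one per 16-weight sub-block
    uint8_t  scales[QK_K/32];    // 4-bit sub-block scales, two per byte
    uint8_t  qs[QK_K/4];         // 2-bit quants, four per byte
};

// Expand one 256-weight super-block to floats, following dequantize_row_iq2_k().
void dequantize_iq2_k_block(const block_iq2_k_sketch & x, float * y) {
    const uint8_t * qs = x.qs;
    uint16_t extra = x.extra;
    int shift = 0;
    for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
        // Two 16-weight sub-blocks share one packed scale byte; a 4-bit code
        // ls decodes to the odd value 2*ls - 15 in [-15, 15].
        const float dl1 = x.d * (2*(x.scales[ib32] & 0xf) - 15);
        const float dl2 = x.d * (2*(x.scales[ib32] >>  4) - 15);
        const int8_t * values1 = (extra & 1) ? iq2nl_values + 4 : iq2nl_values;
        const int8_t * values2 = (extra & 2) ? iq2nl_values + 4 : iq2nl_values;
        extra >>= 2;
        for (int j = 0; j < 16; ++j) {
            y[j +  0] = dl1 * values1[(qs[j +  0] >> shift) & 3];
            y[j + 16] = dl2 * values2[(qs[j + 16] >> shift) & 3];
        }
        y += 32;
        shift += 2;
        if (shift == 8) { qs += 32; shift = 0; }  // next 32 bytes of quants
    }
}

The CUDA dequantization kernel and the AVX-512 DequantizerIQ2K in the diff implement the same unpacking, just vectorized over 128-weight halves of the super-block.
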
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 2397e202..5f599c65 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -40,6 +40,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", }, { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, + { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",}, { "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", }, { "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", }, { "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index ff7f0064..2cb4af32 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -390,6 +390,7 @@ extern "C" { GGML_TYPE_IQ2_BN = 35, GGML_TYPE_Q8_K64 = 36, GGML_TYPE_IQ4_K = 37, + GGML_TYPE_IQ2_K = 38, GGML_TYPE_COUNT, }; @@ -437,6 +438,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ1_BN = 28, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_BN = 29, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_K = 30, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_K = 31, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 755d52b9..9466dfcf 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -454,6 +454,14 @@ typedef struct { } block_iq4_k; static_assert(sizeof(block_iq4_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + 3*QK_K/64, "wrong iq4_k block size/padding"); +typedef struct { + ggml_half d; + uint16_t extra; + uint8_t scales[QK_K/32]; + uint8_t qs[QK_K/4]; +} block_iq2_k; +static_assert(sizeof(block_iq2_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/32 + QK_K/4, "wrong iq2_k block size/padding"); + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL @@ -1890,5 +1898,10 @@ GGML_TABLE_BEGIN(int8_t, iq4k_values, 32) -123, -100, -79, -61, -45, -31, -18, -6, 5, 17, 29, 42, 57, 73, 93, 117 GGML_TABLE_END() +GGML_TABLE_BEGIN(int8_t, iq2nl_values, 8) + -31, -13, 1, 17, -26, -8, 6, 22 +GGML_TABLE_END() + + #endif // GGML_COMMON_IMPL #endif // GGML_COMMON_IMPL diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index cfeda744..a4c93ad6 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2754,6 +2754,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: return true; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8549c4e5..12eebb00 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -670,6 +670,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ2_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> { static constexpr int qk = QK_K; static constexpr int qr = QR4_XS; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index e7732cf5..6dd0fc50 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -543,6 +543,32 @@ static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_ } } +template<typename 
dst_t> +static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq2_k * x = (const block_iq2_k *) vx; + + const int tid = threadIdx.x; + int ib128 = tid/16; // 0 or 1 + int il = tid%16; // 0...15 + dst_t * y = yy + i*QK_K + 128*ib128 + 2*il; + const float d = (float)x[i].d * 1.025f; //1.0325f; + const float dl1 = d * (2*((x[i].scales[4*ib128+0] >> 4*(il/8)) & 0xf) - 15); + const float dl2 = d * (2*((x[i].scales[4*ib128+1] >> 4*(il/8)) & 0xf) - 15); + const float dl3 = d * (2*((x[i].scales[4*ib128+2] >> 4*(il/8)) & 0xf) - 15); + const float dl4 = d * (2*((x[i].scales[4*ib128+3] >> 4*(il/8)) & 0xf) - 15); + const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; + const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); + for (int j = 0; j < 2; ++j) { + y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; + y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]; + y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]; + y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]; + } +} + + template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE); @@ -678,6 +704,12 @@ static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y); } +template<typename dst_t> +static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq2_k<<<nb, 32, 0, stream>>>(vx, y); +} + template <typename src_t, typename dst_t> static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; @@ -744,6 +776,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ4_K: return dequantize_row_iq4_k_cuda; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F32: @@ -797,6 +831,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ4_K: return dequantize_row_iq4_k_cuda; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 5da32d99..b99dc245 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -25,6 +25,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : type == GGML_TYPE_IQ4_K ? vec_dot_iq4_k_q8_1 : + type == GGML_TYPE_IQ2_K ? vec_dot_iq2_k_q8_1 : type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : nullptr; } @@ -48,6 +49,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ : type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ : type == GGML_TYPE_IQ4_K ? VDR_IQ4_K_Q8_1_MMVQ : + type == GGML_TYPE_IQ2_K ? 
VDR_IQ2_K_Q8_1_MMVQ : 1; } @@ -352,6 +354,13 @@ static void mul_mat_vec_iq4_k_q8_1_cuda( mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +static void mul_mat_vec_iq2_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + static void mul_mat_vec_iq3_s_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { @@ -443,6 +452,9 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ4_K: mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ2_K: + mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ3_S: mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 9f2b2300..97a5619f 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1274,3 +1274,35 @@ static __device__ __forceinline__ float vec_dot_iq4_k_q8_1( return d * (sumi1 * ls1 + sumi2 * ls2); } +#define VDR_IQ2_K_Q8_1_MMVQ 4 +#define VDR_IQ2_K_Q8_1_MMQ 4 + +// TODO +static __device__ __forceinline__ float vec_dot_iq2_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + return 0; +// +// const block_iq2_k * bq4 = (const block_iq2_k *) vbq + kbx; +// const uint8_t * all_values = (const uint8_t *)iq4k_values; +// +// // iqs is 0...28 +// const int ib32 = iqs/4; +// // Why iqs/4 ? 
+// const int32_t * q8 = (const int *)bq8_1[ib32].qs; +// const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32; +// const uint16_t extra = bq4->extra >> 2*ib32; +// int v1, v2; +// int sumi1 = 0, sumi2 = 0; +// for (int j = 0; j < 4; ++j) { +// const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16); +// get_int_from_table_16_shift(aux32, extra, all_values, v1, v2); +// sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1); +// sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2); +// } +// const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds); +// const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2); +// const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32; +// const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32; +// return d * (sumi1 * ls1 + sumi2 * ls2); +} + diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index fef124c3..a5dbff12 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -14947,6 +14947,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_IQ2_K: break; case GGML_TYPE_IQ4_K: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 6bfeca1e..0881756d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -992,6 +992,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_IQ2_K] = { + .type_name = "iq2_k", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_k, + .from_float = quantize_row_iq2_k, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_k_ref, + .vec_dot = vec_dot_iq2_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, }; // For internal test use @@ -3342,6 +3354,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break; + case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break; @@ -9592,6 +9605,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -9973,6 +9987,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -10104,6 +10119,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13024,6 +13040,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13215,6 +13232,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13480,6 
+13498,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -14072,6 +14091,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q8_K: @@ -20808,6 +20828,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 1fe0af74..ad09d341 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -742,6 +742,88 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> { }; +struct IQXKScales { + IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm_set1_epi8(shift)), min(_mm256_set1_epi8(min_val)) {} + template <typename Q8> + inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const { + auto extra128 = _mm_set1_epi16(extra); + extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask); + extra128 = _mm_and_si128(extra128, e5); + extra128 = _mm_shuffle_epi8(extra128, eshuffle); + auto scales16 = _mm256_mullo_epi16(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)), + _mm256_add_epi16(_mm256_set1_epi16(-32), _mm256_cvtepi8_epi16(extra128))); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i)); + accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + } + scales16 = MM256_SET_M128I(scales8, scales8); + scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1)); + scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2)); + } + const __m128i eshift; + const __m256i min; + const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101); + const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); + const __m128i e5 = _mm_set1_epi8(5); + const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000); + const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404); +}; + +struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> { + DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), 
iqxk(IQXKScales(5, -32)), values(load_values()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales); + //auto scales8 = make_scales(x[i].scales); + //auto extra128 = _mm_set1_epi16(x[i].extra); + //extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask); + //extra128 = _mm_and_si128(extra128, e5); + //extra128 = _mm_shuffle_epi8(extra128, eshuffle); + //auto scales16 = _mm256_mullo_epi16(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)), + // _mm256_add_epi16(_mm256_set1_epi16(-32), _mm256_cvtepi8_epi16(extra128))); + //for (int iy = 0; iy < Q8::nrc_y; ++iy) { + // const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i)); + // accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + //} + //scales16 = MM256_SET_M128I(scales8, scales8); + //scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1)); + //scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2)); + } + inline void prepare(const uint8_t * q2) { + bits.prepare(q2); + bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]); + } + static inline __m512i load_values() { + static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl); + auto val256 = MM256_SET_M128I(val128, val128); + return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); + } + inline __m128i make_scales(const uint8_t * scales_l) const { + uint64_t aux64; std::memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); + return _mm_add_epi8(_mm_slli_epi16(scl, 1), m15); + } + Q2Bits bits; + IQXKScales iqxk; + + const __m512i values; + const __m128i m15 = _mm_set1_epi8(-15); + //const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + //const __m128i m15 = _mm_set1_epi8(-15); + //const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101); + //const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); + //const __m128i e5 = _mm_set1_epi8(5); + //const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000); + //const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404); +}; + struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} template <typename Q8> @@ -784,11 +866,6 @@ struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { auto sch = _mm_shuffle_epi8(aux, hshuff); return _mm_add_epi8(_mm_or_si128(scl, sch), m32); } - //static __m256i load_shuffle(int i) { - // static const uint64_t k_shuffles[8] = {0x0202020200000000, 0x0a0a0a0a08080808, 0x0303030301010101, 0x0b0b0b0b09090909, - // 0x0606060604040404, 0x0e0e0e0e0c0c0c0c, 0x0707070705050505, 0x0f0f0f0f0d0d0d0d}; - // return _mm256_loadu_si256((const __m256i *)k_shuffles + i); - //} Q4Bits bits; const __m512i values; @@ 
-2897,6 +2974,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { assert (ne00 % QK_K == 0); MulMat::set_functions<DequantizerIQ4XS>(mm); break; + case GGML_TYPE_IQ2_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions<DequantizerIQ2K>(mm); + break; case GGML_TYPE_IQ4_K: assert (ne00 % QK_K == 0); MulMat::set_functions<DequantizerIQ4K>(mm); diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index e60e61a1..7722d630 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -471,8 +471,8 @@ void vec_dot_iq4_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const int8_t * q8 = y[ibl].qs; int32_t sum = 0; for (int ib = 0; ib < QK_K/32; ++ib) { - const int ls1 = (x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30) - 32; - const int ls2 = (x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30) - 32; + const int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; h >>= 4; const int8_t * values1 = iq4k_values + 16*(extra & 1); const int8_t * values2 = iq4k_values + 8*(extra & 2); @@ -698,3 +698,218 @@ size_t quantize_iq4_k(const float * src, void * dst, int64_t nrows, int64_t n_pe } return nrows * nblock * sizeof(block_iq4_k); } + +// +// ============================================== iq2_K +// + +namespace { + +inline int best_index_iq2nl(const int8_t * values, float x) { + int idx = x < values[1] ? 0 : x > values[2] ? 2 : 1; + return x - values[idx] < values[idx+1] - x ? idx : idx + 1; +} + +void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) { + + constexpr int kBlockSize = 16; + + block_iq2_k * y = (block_iq2_k *)vy; + + float scales[QK_K/kBlockSize]; + float weight[kBlockSize]; + float sumx[kBlockSize+1], sumw[kBlockSize+1]; + + std::array<std::pair<float,int>, kBlockSize> pairs; + + const int8_t * shifted_values = iq2nl_values + 4; + + for (int ibl = 0; ibl < n_per_row/QK_K; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq2_k)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + const float * xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = 1.5f*sumx2/QK_K; + + uint16_t extra = 0; + + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { + const float * xb = xbl + kBlockSize*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize; + for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + for (int j = 0; j < kBlockSize; ++j) pairs[j] = {xb[j], j}; + std::sort(pairs.begin(), pairs.end()); + sumx[0] = sumw[0] = 0; + for (int j = 0; j < kBlockSize; ++j) { + int jj = pairs[j].second; + sumw[j+1] = sumw[j] + weight[jj]; + sumx[j+1] = sumx[j] + weight[jj]*xb[jj]; + } + float best = 0, d = 0; + bool is_shifted = false; + float sumqx, sumq2; + for (int i1 = 0; i1 < kBlockSize; ++i1) { + for (int i2 = i1; i2 < kBlockSize; ++i2) { + for (int i3 = i2; i3 < kBlockSize; ++i3) { + sumqx = (sumx[i1] - sumx[ 0])*iq2nl_values[0] + (sumx[i2] - sumx[i1])*iq2nl_values[1] + + (sumx[i3] - sumx[i2])*iq2nl_values[2] + (sumx[kBlockSize] - sumx[i3])*iq2nl_values[3]; + sumq2 = (sumw[i1] - sumw[ 0])*iq2nl_values[0]*iq2nl_values[0] + (sumw[i2] - sumw[i1])*iq2nl_values[1]*iq2nl_values[1] + + (sumw[i3] - sumw[i2])*iq2nl_values[2]*iq2nl_values[2] + (sumw[kBlockSize] - 
sumw[i3])*iq2nl_values[3]*iq2nl_values[3]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = false; + } + sumqx = (sumx[i1] - sumx[ 0])*shifted_values[0] + (sumx[i2] - sumx[i1])*shifted_values[1] + + (sumx[i3] - sumx[i2])*shifted_values[2] + (sumx[kBlockSize] - sumx[i3])*shifted_values[3]; + sumq2 = (sumw[i1] - sumw[ 0])*shifted_values[0]*shifted_values[0] + (sumw[i2] - sumw[i1])*shifted_values[1]*shifted_values[1] + + (sumw[i3] - sumw[i2])*shifted_values[2]*shifted_values[2] + (sumw[kBlockSize] - sumw[i3])*shifted_values[3]*shifted_values[3]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = true; + } + sumqx = (sumx[i1] - sumx[ 0])*iq2nl_values[3] + (sumx[i2] - sumx[i1])*iq2nl_values[2] + + (sumx[i3] - sumx[i2])*iq2nl_values[1] + (sumx[kBlockSize] - sumx[i3])*iq2nl_values[0]; + sumq2 = (sumw[i1] - sumw[ 0])*iq2nl_values[3]*iq2nl_values[3] + (sumw[i2] - sumw[i1])*iq2nl_values[2]*iq2nl_values[2] + + (sumw[i3] - sumw[i2])*iq2nl_values[1]*iq2nl_values[1] + (sumw[kBlockSize] - sumw[i3])*iq2nl_values[0]*iq2nl_values[0]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = false; + } + sumqx = (sumx[i1] - sumx[ 0])*shifted_values[3] + (sumx[i2] - sumx[i1])*shifted_values[2] + + (sumx[i3] - sumx[i2])*shifted_values[1] + (sumx[kBlockSize] - sumx[i3])*shifted_values[0]; + sumq2 = (sumw[i1] - sumw[ 0])*shifted_values[3]*shifted_values[3] + (sumw[i2] - sumw[i1])*shifted_values[2]*shifted_values[2] + + (sumw[i3] - sumw[i2])*shifted_values[1]*shifted_values[1] + (sumw[kBlockSize] - sumw[i3])*shifted_values[0]*shifted_values[0]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = true; + } + } + } + } + scales[ib] = d; + if (is_shifted) extra |= (1 << ib); + + float abs_scale = fabsf(scales[ib]); + max_abs_scale = MAX(max_abs_scale, abs_scale); + } + + if (!max_abs_scale) continue; + + float d = max_abs_scale/15; + y[ibl].d = GGML_FP32_TO_FP16(d); + y[ibl].extra = extra; + float id = 1/d; + + float sumqx = 0, sumq2 = 0; + for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { + int ls = nearest_int(0.5f*(id*scales[ib]+15)); + ls = MAX(0, MIN(15, ls)); + y[ibl].scales[ib/2] |= (ls << 4*(ib%2)); + ls = 2*ls - 15; + float dl = d * ls; + if (dl) { + const int8_t * block_values = y[ibl].extra & (1 << ib) ? 
shifted_values : iq2nl_values; + const float * xb = xbl + kBlockSize*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize; + for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float idl = 1/dl; + int ib32 = ib/2; + int offset = 16*(ib%2); + uint8_t * qs = y[ibl].qs + 32*(ib32/4) + offset; + for (int j = 0; j < 16; ++j) { + const float al = idl*xb[j]; + int ibest = best_index_iq2nl(block_values, al); + qs[j] |= (ibest << 2*(ib32%4)); + float w = weight[j]; + float q = block_values[ibest]*ls; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + } + } + if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(sumqx/sumq2); + + } +} +} + +void quantize_row_iq2_k_ref(const float * GGML_RESTRICT x, block_iq2_k * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq2_k(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq2_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_iq2_k * y = (block_iq2_k *)vy; + quantize_row_iq2_k_ref(x, y, k); +} + +size_t quantize_iq2_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq2_k_impl(src, (void *)qrow, n_per_row, imatrix); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_k); + } + return nrows * nblock * sizeof(block_iq2_k); +} + +void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + + uint16_t extra = x[i].extra; + + int shift = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + float dl1 = d * (2*(x[i].scales[ib32] & 0xf) - 15); + float dl2 = d * (2*(x[i].scales[ib32] >> 4) - 15); + const int8_t * values1 = extra & 1 ? iq2nl_values + 4 : iq2nl_values; + const int8_t * values2 = extra & 2 ? 
iq2nl_values + 4 : iq2nl_values; + extra >>= 2; + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl1 * values1[(qs[j+ 0] >> shift) & 3]; + y[j+16] = dl2 * values2[(qs[j+16] >> shift) & 3]; + } + y += 32; + shift += 2; + if (shift == 8) { qs += 32; shift = 0; } + } + + } + +} + +void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + const int nb = n / QK_K; + + const block_iq2_k * x = (const block_iq2_k *)vx; + const block_q8_K * y = (const block_q8_K *)vy; +} diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index dcc12dd2..f36eff38 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -19,6 +19,12 @@ size_t quantize_iq4_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, void dequantize_row_iq4_k(const block_iq4_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq4_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq2_k_ref(const float * GGML_RESTRICT x, block_iq2_k * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + #ifdef __cplusplus } #endif diff --git a/include/llama.h b/include/llama.h index a90b56e1..3549d3f3 100644 --- a/include/llama.h +++ b/include/llama.h @@ -168,9 +168,10 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors - LLAMA_FTYPE_MOSTLY_IQ1_BN = 36, - LLAMA_FTYPE_MOSTLY_IQ2_BN = 37, + LLAMA_FTYPE_MOSTLY_IQ1_BN = 36, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_BN = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_K = 38, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_K = 39, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index a87bfe59..3f9a211c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3762,6 +3762,7 @@ struct llama_model_loader { case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; + case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; @@ -4458,6 +4459,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_K: return "IQ4_K - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_K: return 
"IQ2_K - 2.375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4"; @@ -15407,7 +15409,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || - ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { + ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K) { new_type = GGML_TYPE_Q5_K; } else if (new_type != GGML_TYPE_Q8_0) { @@ -15464,6 +15466,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_K) { + if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_IQ4_K; + } else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) { new_type = GGML_TYPE_Q4_K; } @@ -15573,9 +15578,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n if (arch != LLM_ARCH_FALCON) { if (qs.model.hparams.n_expert == 8) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || - ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || - ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || - ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K) { + ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || + ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || + ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_K) { new_type = GGML_TYPE_Q5_K; } } else { @@ -15629,7 +15635,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS || new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || - new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K) { + new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -15656,6 +15662,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_IQ1_M: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: + case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; case GGML_TYPE_IQ4_K: case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; @@ -15762,6 +15769,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ4_K: default_type = GGML_TYPE_IQ4_K; break; + case LLAMA_FTYPE_MOSTLY_IQ2_K: 
default_type = GGML_TYPE_IQ2_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break; |
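
On the encoding side, the core of quantize_row_iq2_k_impl() is a per-sub-block search: the 16 weights are sorted, so every assignment to the four monotone grid values is described by three split points i1 <= i2 <= i3 in sorted order, and the prefix sums sumx/sumw make each candidate's weighted sums O(1). For a fixed assignment q[j], the sub-block scale minimizing the weighted error sum_j w[j]*(x[j] - d*q[j])^2 is d = sum(w*q*x)/sum(w*q*q), and the corresponding error reduction is d*sum(w*q*x), which is the "best" score the triple loop keeps while trying the base grid, the shifted grid, and both mirrored orderings. Below is a minimal sketch of that closed-form fit only, not the full split-point search; fit_scale is an illustrative helper name, not a function from the diff.

#include <cstdint>

struct ScaleFit { float d; float score; };   // score = sumqx^2 / sumq2 = d * sumqx

// Weighted least-squares scale for one candidate assignment of a sub-block:
// levels[j] is the iq2nl_values entry chosen for weight j, w[j] the importance
// weight (from the imatrix or the sigma2-based default), x[j] the original value.
ScaleFit fit_scale(const float * x, const float * w, const int8_t * levels, int n) {
    float sumqx = 0.0f, sumq2 = 0.0f;
    for (int j = 0; j < n; ++j) {
        sumqx += w[j] * levels[j] * x[j];
        sumq2 += w[j] * levels[j] * levels[j];
    }
    ScaleFit r{0.0f, 0.0f};
    if (sumq2 > 0.0f) { r.d = sumqx / sumq2; r.score = r.d * sumqx; }
    return r;
}

The winning per-sub-block scales are then re-expressed relative to the fp16 super-block scale d = max|scale|/15 as 4-bit codes ls = nearest_int(0.5*(scale/d + 15)), clamped to [0, 15] and decoded as 2*ls - 15; a final pass picks the nearest grid entry per weight with best_index_iq2nl() and refits the super-block scale against the actual quantized values (the trailing sumqx/sumq2 update in the impl).
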