diff options
-rw-r--r-- | examples/quantize/quantize.cpp | 1 | ||||
-rw-r--r-- | ggml/include/ggml.h | 10 | ||||
-rw-r--r-- | ggml/src/ggml-common.h | 15 | ||||
-rw-r--r-- | ggml/src/ggml-cuda.cu | 3 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 7 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 57 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 23 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cuh | 4 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 9 | ||||
-rw-r--r-- | ggml/src/ggml-quants.c | 1 | ||||
-rw-r--r-- | ggml/src/ggml.c | 21 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 268 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 6 | ||||
-rw-r--r-- | include/llama.h | 5 | ||||
-rw-r--r-- | src/llama.cpp | 14 |
15 files changed, 416 insertions, 28 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 17e87e53..0b4c3444 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -41,6 +41,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",}, + { "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", }, { "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", }, { "IQ5_K", LLAMA_FTYPE_MOSTLY_IQ5_K, " 5.5 bpw non-linear quantization", }, { "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b7585ad6..94ffae7e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -390,8 +390,9 @@ extern "C" { GGML_TYPE_IQ2_BN = 35, GGML_TYPE_Q8_K64 = 36, GGML_TYPE_IQ2_K = 37, - GGML_TYPE_IQ4_K = 38, - GGML_TYPE_IQ5_K = 39, + GGML_TYPE_IQ3_K = 38, + GGML_TYPE_IQ4_K = 39, + GGML_TYPE_IQ5_K = 40, GGML_TYPE_COUNT, }; @@ -439,8 +440,9 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ1_BN = 28, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_BN = 29, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_K = 30, // except 1d tensors - GGML_FTYPE_MOSTLY_IQ4_K = 31, // except 1d tensors - GGML_FTYPE_MOSTLY_IQ5_K = 32, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 7da27794..423797b6 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -459,6 +459,16 @@ static_assert(sizeof(block_iq2_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K typedef struct { ggml_half d; uint16_t extra; + uint16_t scales_h; + uint8_t scales_l[QK_K/32]; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/8]; +} block_iq3_k; +static_assert(sizeof(block_iq3_k) == sizeof(ggml_half) + 2*sizeof(uint16_t) + QK_K/32 + QK_K/4 + QK_K/8, "wrong iq3_k block size/padding"); + +typedef struct { + ggml_half d; + uint16_t extra; uint8_t scales_h[QK_K/64]; uint8_t scales_l[QK_K/32]; uint8_t qs[QK_K/2]; @@ -1911,6 +1921,11 @@ GGML_TABLE_BEGIN(int8_t, iq2nl_values, 8) -31, -13, 1, 17, -26, -8, 6, 22 GGML_TABLE_END() +GGML_TABLE_BEGIN(int8_t, iq3nl_values, 16) + -63, -40, -23, -10, 1, 13, 28, 47, + -59, -36, -19, -6, 5, 17, 32, 51, +GGML_TABLE_END() + GGML_TABLE_BEGIN(int8_t, iq4k_values, 32) -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113, -123, -100, -79, -61, -45, -31, -18, -6, 5, 17, 29, 42, 57, 73, 93, 117 diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index ba9d89aa..d34aa386 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2753,9 +2753,10 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: - case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: return true; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 516e74d8..fbc52aa9 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -677,6 +677,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_K> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> { static constexpr int qk = QK_K; static constexpr int qr = QR4_XS; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index f388e9f3..ed7e4bd0 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -595,6 +595,33 @@ static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_ } } +template<typename dst_t> +static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq3_k * x = (const block_iq3_k *) vx; + + const int tid = threadIdx.x; + int ib128 = tid/16; // 0 or 1 + int il = tid%16; // 0...15 + dst_t * y = yy + i*QK_K + 128*ib128 + 2*il; + const float d = (float)x[i].d * 1.01f; //1.0125f; + const uint16_t sh = x[i].scales_h >> (8*ib128 + (il/8)); + const float dl1 = d * ((2*((x[i].scales_l[4*ib128+0] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x01) ? -1 : 1)); + const float dl2 = d * ((2*((x[i].scales_l[4*ib128+1] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x04) ? -1 : 1)); + const float dl3 = d * ((2*((x[i].scales_l[4*ib128+2] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x10) ? -1 : 1)); + const float dl4 = d * ((2*((x[i].scales_l[4*ib128+3] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x40) ? -1 : 1)); + const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; + const uint8_t * qh = x[i].qh + 2*il; + const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); + for (int j = 0; j < 2; ++j) { + const uint8_t h = qh[j] >> (4*(ib128%2)); + y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]; + y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]; + y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]; + y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]; + } +} template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { @@ -726,21 +753,27 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t } template<typename dst_t> -static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { +static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; - dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y); + dequantize_block_iq2_k<<<nb, 32, 0, stream>>>(vx, y); } template<typename dst_t> -static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { +static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; - dequantize_block_iq5_k<<<nb, 32, 0, stream>>>(vx, y); + dequantize_block_iq3_k<<<nb, 32, 0, stream>>>(vx, y); } template<typename dst_t> -static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { +static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; - dequantize_block_iq2_k<<<nb, 32, 0, stream>>>(vx, y); + dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq5_k<<<nb, 32, 0, stream>>>(vx, y); } template <typename src_t, typename dst_t> @@ -807,12 +840,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda; + case GGML_TYPE_IQ3_K: + return dequantize_row_iq3_k_cuda; case GGML_TYPE_IQ4_K: return dequantize_row_iq4_k_cuda; case GGML_TYPE_IQ5_K: return dequantize_row_iq5_k_cuda; - case GGML_TYPE_IQ2_K: - return dequantize_row_iq2_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F32: @@ -864,12 +899,14 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda; + case GGML_TYPE_IQ3_K: + return dequantize_row_iq3_k_cuda; case GGML_TYPE_IQ4_K: return dequantize_row_iq4_k_cuda; case GGML_TYPE_IQ5_K: return dequantize_row_iq5_k_cuda; - case GGML_TYPE_IQ2_K: - return dequantize_row_iq2_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 3c2277ed..bf7b2aa7 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -245,9 +245,6 @@ __device__ __forceinline__ float vec_dot_iq5_k_q8_1( return d5 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2); } -#define VDR_IQ2_K_Q8_1_MMVQ 4 -#define VDR_IQ2_K_Q8_1_MMQ 4 - static const __device__ uint32_t iq2k_table[512] = { 0xe1e1e1e1, 0xe1e1e1f3, 0xe1e1e101, 0xe1e1e111, 0xe1e1f3e1, 0xe1e1f3f3, 0xe1e1f301, 0xe1e1f311, 0xe1e101e1, 0xe1e101f3, 0xe1e10101, 0xe1e10111, 0xe1e111e1, 0xe1e111f3, 0xe1e11101, 0xe1e11111, @@ -319,6 +316,9 @@ __device__ __forceinline__ int int_from_table_4(const uint8_t * a8, const int * return values[a8[0] | (a8[1] << 2) | (a8[2] << 4) | (a8[3] << 6)]; } +#define VDR_IQ2_K_Q8_1_MMVQ 4 +#define VDR_IQ2_K_Q8_1_MMQ 4 + __device__ __forceinline__ float vec_dot_iq2_k_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { @@ -378,8 +378,18 @@ __device__ __forceinline__ float vec_dot_iq2_k_q8_1( } +#define VDR_IQ3_K_Q8_1_MMVQ 4 +#define VDR_IQ3_K_Q8_1_MMQ 4 + +// TODO +__device__ __forceinline__ float vec_dot_iq3_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + return 0; + } +} // namespace + void mul_mat_vec_iq2_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { @@ -387,6 +397,13 @@ void mul_mat_vec_iq2_k_q8_1_cuda( iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K, VDR_IQ2_K_Q8_1_MMVQ, vec_dot_iq2_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +void mul_mat_vec_iq3_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_K, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq3_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + void mul_mat_vec_iq4_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 14e5c1c7..9a33af0d 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -4,6 +4,10 @@ void mul_mat_vec_iq2_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); +void mul_mat_vec_iq3_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + void mul_mat_vec_iq4_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 93c8ac29..56bf3ebe 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -432,15 +432,18 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ4_XS: mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ2_K: + mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; + case GGML_TYPE_IQ3_K: + mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ4_K: mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_IQ5_K: mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; - case GGML_TYPE_IQ2_K: - mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; case GGML_TYPE_IQ3_S: mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 4b3bf361..c2c66f38 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -14948,6 +14948,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; case GGML_TYPE_IQ2_K: break; + case GGML_TYPE_IQ3_K: break; case GGML_TYPE_IQ4_K: break; case GGML_TYPE_IQ5_K: break; case GGML_TYPE_Q4_0_4_4: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f873e49a..4ce9948d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -992,6 +992,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_IQ3_K] = { + .type_name = "iq3_k", + .blck_size = QK_K, + .type_size = sizeof(block_iq3_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq3_k, + .from_float = quantize_row_iq3_k, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_k_ref, + .vec_dot = vec_dot_iq3_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, [GGML_TYPE_IQ4_K] = { .type_name = "iq4_k", .blck_size = QK_K, @@ -3366,6 +3378,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break; + case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break; case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break; case GGML_FTYPE_MOSTLY_IQ5_K: wtype = GGML_TYPE_IQ5_K; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; @@ -9618,6 +9631,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ3_S: @@ -10001,6 +10015,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ3_S: @@ -10134,6 +10149,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ3_S: @@ -13056,6 +13072,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ3_S: @@ -13249,6 +13266,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ3_S: @@ -13516,6 +13534,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ3_S: @@ -14110,6 +14129,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ3_S: @@ -20848,6 +20868,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ5_K: result = quantize_iq5_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 9c502f07..24076467 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -629,6 +629,274 @@ void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * } // +// ============================================== iq3_k +// +namespace { +static int8_t iq3nl_index[69] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5 +}; +static inline int best_index_iq3nl(const int8_t * values, float x) { + int index = x < values[1] ? 0 : x >= values[6] ? 6 : iq3nl_index[(int)x - values[1]]; + return x - values[index] < values[index+1] - x ? index : index+1; +} + +static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) { + + const int ntry = 5; + + block_iq3_k * y = (block_iq3_k *)vy; + + float scales[QK_K/16]; + float weight[16]; + + const int8_t * shifted_values = iq3nl_values + 8; + + for (int ibl = 0; ibl < n_per_row/QK_K; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq3_k)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + const float * xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = sumx2/QK_K; + + uint16_t extra = 0; + + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*16; + for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < 16; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < 16; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (!amax) { + scales[ib] = 0; + continue; + } + float d = ntry > 0 ? -max/iq3nl_values[0] : max/iq3nl_values[0]; + float id = 1/d; + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(iq3nl_values, al); + float q = iq3nl_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq3nl(iq3nl_values, -al); + q = iq3nl_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + d = sumqx_p/sumq2_p; + float best = d*sumqx_p; + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; + } + bool is_shifted = false; + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + iq3nl_values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(iq3nl_values, al); + float q = iq3nl_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq3nl(iq3nl_values, -al); + q = iq3nl_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; + } + id = (itry + shifted_values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(shifted_values, al); + float q = shifted_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq3nl(shifted_values, -al); + q = shifted_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; + } + } + if (d) { + const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values; + float sumqx = 0, sumq2 = 0; + id = 1/d; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(block_values, al); + float q = block_values[l]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0) d = sumqx/sumq2; + } + scales[ib] = d; + + if (is_shifted) extra |= (1 << ib); + + float abs_scale = fabsf(scales[ib]); + max_abs_scale = MAX(max_abs_scale, abs_scale); + } + + if (!max_abs_scale) continue; + + float d = max_abs_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d); + y[ibl].extra = extra; + float id = 1/d; + + float sumqx = 0, sumq2 = 0; + for (int ib = 0; ib < QK_K/16; ++ib) { + int ls = nearest_int(0.5f*(id*fabsf(scales[ib])-1)); + ls = MAX(0, MIN(15, ls)); + y[ibl].scales_l[ib/2] |= (ls << 4*(ib%2)); + if (scales[ib] < 0) y[ibl].scales_h |= (1 << ib); + ls = (2*ls + 1) * (scales[ib] < 0 ? -1 : 1); + float dl = d * ls; + if (dl) { + const int8_t * block_values = y[ibl].extra & (1 << ib) ? shifted_values : iq3nl_values; + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*16; + for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < 16; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float idl = 1/dl; + int ib32 = ib/2; + int offset = 16*(ib%2); + uint8_t * qs = y[ibl].qs + 32*(ib32/4) + offset; + uint8_t * qh = y[ibl].qh + 32*(ib32/8) + offset; + for (int j = 0; j < 16; ++j) { + const float al = idl*xb[j]; + int ibest = best_index_iq3nl(block_values, al); + qs[j] |= ((ibest & 3) << 2*(ib32%4)); + qh[j] |= ((ibest >> 2) << (ib32%8)); + float w = weight[j]; + float q = block_values[ibest]*ls; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + } + } + if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(sumqx/sumq2); + + } +} + +} + +void quantize_row_iq3_k_ref(const float * x, block_iq3_k * y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq3_k(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq3_k(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + block_iq3_k * y = (block_iq3_k *)vy; + quantize_row_iq3_k_ref(x, y, k); +} + +size_t quantize_iq3_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq3_k_impl(src, (void *)qrow, n_per_row, imatrix); + src += n_per_row; + qrow += nblock*sizeof(block_iq3_k); + } + return nrows * nblock * sizeof(block_iq3_k); +} + +void dequantize_row_iq3_k(const block_iq3_k * x, float * y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + + uint16_t sh = x[i].scales_h; + uint16_t extra = x[i].extra; + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + float dl1 = d * ((2*(x[i].scales_l[ib32] & 0xf) + 1) * ((sh & 1) ? -1 : 1)); + float dl2 = d * ((2*(x[i].scales_l[ib32] >> 4) + 1) * ((sh & 2) ? -1 : 1)); + sh >>= 2; + const int8_t * values1 = extra & 1 ? iq3nl_values + 8 : iq3nl_values; + const int8_t * values2 = extra & 2 ? iq3nl_values + 8 : iq3nl_values; + extra >>= 2; + int shift_l = 2*(ib32%4); + int shift_h = ib32%8; + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl1 * values1[((qs[j+ 0] >> shift_l) & 3) | (((qh[j+ 0] >> shift_h) & 1) << 2)]; + y[j+16] = dl2 * values2[((qs[j+16] >> shift_l) & 3) | (((qh[j+16] >> shift_h) & 1) << 2)]; + } + y += 32; + if (shift_l == 6) qs += 32; + } + + } +} + +void vec_dot_iq3_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ3_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + const int nb = n / QK_K; + + const block_iq2_k * x = (const block_iq2_k *)vx; + const block_q8_K * y = (const block_q8_K *)vy; +} + +// // ============================================== iq4_K // void dequantize_row_iq4_k(const block_iq4_k * x, float * y, int64_t k) { diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index b8b03169..0295eb99 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -19,6 +19,12 @@ size_t quantize_iq2_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq3_k_ref(const float * GGML_RESTRICT x, block_iq3_k * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq3_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq3_k(const block_iq3_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq3_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void quantize_row_iq4_k_ref(const float * GGML_RESTRICT x, block_iq4_k * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_iq4_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/include/llama.h b/include/llama.h index 7bccd4bb..88d82958 100644 --- a/include/llama.h +++ b/include/llama.h @@ -171,8 +171,9 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ1_BN = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_BN = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_K = 38, // except 1d tensors - LLAMA_FTYPE_MOSTLY_IQ4_K = 39, // except 1d tensors - LLAMA_FTYPE_MOSTLY_IQ5_K = 40, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_K = 39, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_K = 40, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ5_K = 41, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index 4e7e4a6c..2caaf7d0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3762,6 +3762,7 @@ struct llama_model_loader { case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; + case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break; case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; case GGML_TYPE_IQ5_K: ftype = LLAMA_FTYPE_MOSTLY_IQ5_K; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; @@ -4460,6 +4461,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_K: return "IQ3_K - 3.4325 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_K: return "IQ4_K - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ5_K: return "IQ5_K - 5.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; @@ -15477,8 +15479,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS; } - else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) { - new_type = GGML_TYPE_Q4_K; + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K) && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_IQ4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { new_type = GGML_TYPE_Q4_K; @@ -15578,12 +15580,12 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ++qs.i_ffn_down; } else if (name.find("attn_output.weight") != std::string::npos) { if (arch != LLM_ARCH_FALCON) { - if (qs.model.hparams.n_expert == 8) { + if (qs.model.hparams.n_expert >= 8) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_K) { + ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K) { new_type = GGML_TYPE_Q5_K; } } else { @@ -15638,7 +15640,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || - new_type == GGML_TYPE_IQ5_K) { + new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -15666,6 +15668,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; case GGML_TYPE_IQ4_K: case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; @@ -15773,6 +15776,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break; + case LLAMA_FTYPE_MOSTLY_IQ3_K: default_type = GGML_TYPE_IQ3_K; break; case LLAMA_FTYPE_MOSTLY_IQ4_K: default_type = GGML_TYPE_IQ4_K; break; case LLAMA_FTYPE_MOSTLY_IQ5_K: default_type = GGML_TYPE_IQ5_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; |