diff options
author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-07-28 12:11:59 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-28 12:11:59 +0200 |
commit | 291066e6df5318c322a03e592483aae8820d3b19 (patch) | |
tree | 1c8cafa8d0bc73c3aa39c71ab53b53eb307d3774 | |
parent | f62615b44f7df586cb58ed9fffca59b96820117b (diff) |
IQ4_K: SOTA 4-bit quantization (#6)
* iq4_k: basics
* quantize/dequantize works
* CUDA dequantize works and one can run PPL calcs. I get
PPL = 6.5258 for LlaMA-3.1-8B, which is 1.77% above fp16.
In comparison, q4_K_S (same size) is 2.88% above fp16.
* TG on CUDA does not work. Johannes has changed the way i-quant dot
products are done, so need to sort out what he had in mind
* iqk_mul_mat is not implemented.
* iq4_k: TG now works on CUDA
* iq4_k: AVX512 implementation
For LLaMA-3.1-8B we get PP-512 = 182.6 t/s, TG-128 = 13.6 t/s,
so almost the same as q4_K_S.
* iq4_k: AVX2 implementation
For LLaMA-3.1-8B we get PP-512 = 203.1 t/s, TG-128 = 12.9 t/s
on the Ryzen-5975X.
* iq4_k: NEON implementation
For LLaMA-3.1-8B we get PP-512 = 60.7 t/s, TG-128 = 25.0 t/s
on the M2-Max. TG is on par with q4_K_S, PP is ~10% slower.
* iq4_k: Metal implementation
For LLaMA-3.1-8B we get PP-512 = 445 t/s, TG-128 = 46.3 t/s
on a 30-core M2-Max GPU. This is to be compared with (currently)
PP-512 = 460 t/s, TG-128 = 51 t/s for q4_K_S.
* iq4_k: scalar dot product
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | examples/quantize/quantize.cpp | 1 | ||||
-rw-r--r-- | ggml/include/ggml.h | 2 | ||||
-rw-r--r-- | ggml/src/ggml-common.h | 14 | ||||
-rw-r--r-- | ggml/src/ggml-cuda.cu | 1 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 7 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 32 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 12 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/vecdotq.cuh | 47 | ||||
-rw-r--r-- | ggml/src/ggml-metal.m | 29 | ||||
-rw-r--r-- | ggml/src/ggml-metal.metal | 164 | ||||
-rw-r--r-- | ggml/src/ggml-quants.c | 1 | ||||
-rw-r--r-- | ggml/src/ggml.c | 24 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 223 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 286 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 24 | ||||
-rw-r--r-- | include/llama.h | 1 | ||||
-rw-r--r-- | src/llama.cpp | 17 |
17 files changed, 864 insertions, 21 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 059d67a6..2397e202 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -40,6 +40,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", }, { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, + { "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", }, { "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", }, { "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", }, { "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5ed8f73d..ff7f0064 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -389,6 +389,7 @@ extern "C" { GGML_TYPE_IQ1_BN = 34, GGML_TYPE_IQ2_BN = 35, GGML_TYPE_Q8_K64 = 36, + GGML_TYPE_IQ4_K = 37, GGML_TYPE_COUNT, }; @@ -435,6 +436,7 @@ extern "C" { GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_BN = 28, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_BN = 29, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_K = 30, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index da3f1b3c..755d52b9 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -445,6 +445,15 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +typedef struct { + ggml_half d; + uint16_t extra; + uint8_t scales_h[QK_K/64]; + uint8_t scales_l[QK_K/32]; + uint8_t qs[QK_K/2]; +} block_iq4_k; +static_assert(sizeof(block_iq4_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + 3*QK_K/64, "wrong iq4_k block size/padding"); + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL @@ -1876,5 +1885,10 @@ GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S) GGML_TABLE_END() #endif +GGML_TABLE_BEGIN(int8_t, iq4k_values, 32) + -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113, + -123, -100, -79, -61, -45, -31, -18, -6, 5, 17, 29, 42, 57, 73, 93, 117 +GGML_TABLE_END() + #endif // GGML_COMMON_IMPL #endif // GGML_COMMON_IMPL diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 59cf434c..cfeda744 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2753,6 +2753,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: return true; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 7ea93264..8549c4e5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -670,6 +670,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> { static constexpr int qk = QK_K; static constexpr int qr = QR3_S; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 66e68a52..e7732cf5 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -521,6 +521,28 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst } } +template<typename dst_t> +static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const int64_t i = blockIdx.x; + const block_iq4_k * x = (const block_iq4_k *)vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[i].qs + 16*ib + 4*il; + const float d = (float)x[i].d; + const uint8_t sh = x[i].scales_h[ib/2] >> 4*(ib%2); + const float d1 = d * (((x[i].scales_l[ib] & 0xf) | ((sh << 4) & 0x30)) - 32); + const float d2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32); + const int8_t * values1 = iq4k_values + 16*((x[i].extra >> (2*ib+0)) & 1); + const int8_t * values2 = iq4k_values + 16*((x[i].extra >> (2*ib+1)) & 1); + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d1 * values1[q4[j] & 0xf]; + y[j+16] = d2 * values2[q4[j] >> 4]; + } +} + template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE); @@ -650,6 +672,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y); } +template<typename dst_t> +static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y); +} + template <typename src_t, typename dst_t> static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; @@ -714,6 +742,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ4_K: + return dequantize_row_iq4_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F32: @@ -765,6 +795,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ4_K: + return dequantize_row_iq4_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index b44000cd..5da32d99 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -24,6 +24,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) type == GGML_TYPE_IQ2_BN ? vec_dot_iq2_bn_q8_1 : type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : + type == GGML_TYPE_IQ4_K ? vec_dot_iq4_k_q8_1 : type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : nullptr; } @@ -46,6 +47,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ : type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ : type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ : + type == GGML_TYPE_IQ4_K ? VDR_IQ4_K_Q8_1_MMVQ : 1; } @@ -343,6 +345,13 @@ static void mul_mat_vec_iq4_xs_q8_1_cuda( mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +static void mul_mat_vec_iq4_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + static void mul_mat_vec_iq3_s_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { @@ -431,6 +440,9 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ4_XS: mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ4_K: + mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ3_S: mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 1248eacd..9f2b2300 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1227,3 +1227,50 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds); return d * sumi; } + +static __device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values, + int & val1, int & val2) { + + uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32; + aux32 = q4 & 0x0f0f0f0f; + const uint8_t * values = all_values + 16*(shift & 1); + uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8); + uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8); + val1 = v1 | (v2 << 16); + aux32 = (q4 >> 4) & 0x0f0f0f0f; + values = all_values + 8*(shift & 2); + v1 = values[q8[0]] | (values[q8[1]] << 8); + v2 = values[q8[2]] | (values[q8[3]] << 8); + val2 = v1 | (v2 << 16); +} + +#define VDR_IQ4_K_Q8_1_MMVQ 4 +#define VDR_IQ4_K_Q8_1_MMQ 4 + +static __device__ __forceinline__ float vec_dot_iq4_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq4_k * bq4 = (const block_iq4_k *) vbq + kbx; + const uint8_t * all_values = (const uint8_t *)iq4k_values; + + // iqs is 0...28 + const int ib32 = iqs/4; + // Why iqs/4 ? + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32; + const uint16_t extra = bq4->extra >> 2*ib32; + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16); + get_int_from_table_16_shift(aux32, extra, all_values, v1, v2); + sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1); + sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2); + } + const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds); + const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2); + const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32; + return d * (sumi1 * ls1 + sumi2 * ls2); +} + diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 388c2008..37bf8cc4 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -90,6 +90,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, @@ -120,6 +121,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, @@ -146,6 +148,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, @@ -169,6 +172,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, @@ -192,6 +196,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, @@ -559,6 +564,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K, get_rows_iq4_k, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); @@ -589,6 +595,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32, mul_mv_iq4_k_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); @@ -615,6 +622,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32, mul_mv_id_iq4_k_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); @@ -638,6 +646,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, mul_mm_iq4_k_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); @@ -661,6 +670,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32, mul_mm_id_iq4_k_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); @@ -1690,6 +1700,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } @@ -1872,6 +1883,12 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; } break; + case GGML_TYPE_IQ4_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32].pipeline; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); @@ -1916,7 +1933,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K) { const int mem_size = 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -2007,6 +2024,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); } @@ -2183,6 +2201,12 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; } break; + case GGML_TYPE_IQ4_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); @@ -2238,7 +2262,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K) { const int mem_size = 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -2288,6 +2312,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break; case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 67dcf53d..9bc33ceb 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3056,6 +3056,11 @@ constexpr constant static float kvalues_iq4nl_f[16] = { -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f }; +constexpr constant static float kvalues_iq4k_f[32] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f, + -123.f, -100.f, -79.f, -61.f, -45.f, -31.f, -18.f, -6.f, 5.f, 17.f, 29.f, 42.f, 57.f, 73.f, 93.f, 117.f, +}; + kernel void kernel_cpy_f32_iq4_nl( device const float * src0, device void * dst, @@ -5187,6 +5192,111 @@ void kernel_mul_mv_iq4_xs_f32_impl( } } +void kernel_mul_mv_iq4_k_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_k * x = (device const block_iq4_k *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib = it/2; + const int il = it%2; + + shared_values[tiisg] = kvalues_iq4k_f[tiisg]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + //float2 sumy; + //sumy[0] = -4.f*(yl[0][0] + yl[0][1] + yl[0][2] + yl[0][3] + yl[2][0] + yl[2][1] + yl[2][2] + yl[2][3]); + //sumy[1] = -4.f*(yl[1][0] + yl[1][1] + yl[1][2] + yl[1][3] + yl[3][0] + yl[3][1] + yl[3][2] + yl[3][3]); + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_k & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + uint16_t extra = xb.extra >> 2*ib; + threadgroup const float * values1 = shared_values + 16*(extra & 1); + threadgroup const float * values2 = shared_values + 8*(extra & 2); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; + qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; + qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; + qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + const uint8_t h = xb.scales_h[ib/2] >> 4*(ib%2); + const int ls1 = ((xb.scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls2 = ((xb.scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3])); + //uint16_t extra = xb.extra >> 2*ib; + //sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3] + (extra & 1 ? sumy[0] : 0)) + + // ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3] + (extra & 2 ? sumy[1] : 0))); + + } + + yb += 2 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( device const void * src0, @@ -5357,6 +5467,35 @@ kernel void kernel_mul_mv_iq4_xs_f32( kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } +[[host_name("kernel_mul_mv_iq4_k_f32")]] +kernel void kernel_mul_mv_iq4_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template @@ -5827,6 +5966,27 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 } } +template <typename type4x4> +void dequantize_iq4_k(device const block_iq4_k * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + const int l = il%2; + // l = 0 or 1. l = 0 processes the first 16 quants in a block of 32, l = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32] >> 4*l) & 0xf) | (((xb->scales_h[il/4] >> 2*(il%4)) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + constant float * values = kvalues_iq4k_f + 16*((xb->extra >> il) & 1); + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*l) & 0x0f0f0f0f; + reg[i][0] = d * values[q8[0]]; + reg[i][1] = d * values[q8[1]]; + reg[i][2] = d * values[q8[2]]; + reg[i][3] = d * values[q8[3]]; + } +} + template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)> kernel void kernel_get_rows_q( device const void * src0, @@ -6288,6 +6448,7 @@ template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>; template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>; +template [[host_name("kernel_get_rows_iq4_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_k, QK_NL, dequantize_iq4_k>; template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_bn, 4, dequantize_iq1_bn>; template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_bn, 4, dequantize_iq2_bn>; @@ -6318,6 +6479,7 @@ template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>; +template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_k, QK_NL, dequantize_iq4_k>; template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_bn, 4, dequantize_iq1_bn>; template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_bn, 4, dequantize_iq2_bn>; @@ -6350,6 +6512,7 @@ template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_bn, 4, dequantize_iq2_bn>; template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>; +template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_k, QK_NL, dequantize_iq4_k>; // // matrix-vector multiplication @@ -6561,3 +6724,4 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>; +template [[host_name("kernel_mul_mv_id_iq4_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_k_f32_impl>>; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index da4c9b9a..fef124c3 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -14947,6 +14947,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_IQ4_K: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c3cda4c4..6bfeca1e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -11,6 +11,7 @@ #include "ggml-quants.h" #include "ggml.h" #include "ggml-aarch64.h" +#include "iqk/iqk_quantize.h" #if GGML_USE_IQK_MULMAT #include "iqk/iqk_mul_mat.h" #endif @@ -978,7 +979,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .ncols = 8, .gemv = ggml_gemv_q4_0_8x8_q8_0, .gemm = ggml_gemm_q4_0_8x8_q8_0, - } + }, + [GGML_TYPE_IQ4_K] = { + .type_name = "iq4_k", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_k, + .from_float = quantize_row_iq4_k, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_k_ref, + .vec_dot = vec_dot_iq4_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, }; // For internal test use @@ -3328,6 +3341,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break; case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; + case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break; @@ -9577,6 +9591,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -9957,6 +9972,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -10087,6 +10103,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13006,6 +13023,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13196,6 +13214,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13460,6 +13479,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -14051,6 +14071,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q8_K: @@ -20786,6 +20807,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index bf517504..1fe0af74 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -571,8 +571,16 @@ struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> { Scales8K s8k; }; +__m512i load_iq4nl_values_512() { + static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); + auto val256 = MM256_SET_M128I(val128, val128); + return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); +} + + struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> { - DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {} + DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} template <typename Q8> inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); @@ -584,12 +592,6 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> { scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); } - static __m512i load_values() { - static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; - auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); - auto val256 = MM256_SET_M128I(val128, val128); - return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); - } inline void prepare(const uint8_t * q4) { bits.prepare64(q4); // We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111 @@ -740,6 +742,70 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> { }; +struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { + DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs); + auto scales8 = make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h); + auto extra128 = _mm_set1_epi16(x[i].extra); + extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask); + extra128 = _mm_and_si128(extra128, e4); + extra128 = _mm_shuffle_epi8(extra128, eshuffle); + auto scales16 = _mm256_mullo_epi16(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff)), + _mm256_add_epi16(_mm256_set1_epi16(-128), _mm256_cvtepi8_epi16(extra128))); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i)); + accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + } + scales16 = MM256_SET_M128I(scales8, scales8); + scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1)); + scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2)); + } + inline void prepare(const uint8_t * q4) { + bits.prepare64(q4); + // We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111 + // bits.valuse[1]: 16..31, 48...63, 80...95, 112..127 + // etc. + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1])); + bits.values[0] = _mm512_shuffle_epi8(values, tmp); + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3])); + bits.values[2] = _mm512_shuffle_epi8(values, tmp); + } + __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { + uint64_t aux64; + memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); + auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); + auto sch = _mm_shuffle_epi8(aux, hshuff); + return _mm_add_epi8(_mm_or_si128(scl, sch), m32); + } + //static __m256i load_shuffle(int i) { + // static const uint64_t k_shuffles[8] = {0x0202020200000000, 0x0a0a0a0a08080808, 0x0303030301010101, 0x0b0b0b0b09090909, + // 0x0606060604040404, 0x0e0e0e0e0c0c0c0c, 0x0707070705050505, 0x0f0f0f0f0d0d0d0d}; + // return _mm256_loadu_si256((const __m256i *)k_shuffles + i); + //} + + Q4Bits bits; + const __m512i values; + const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); + const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000); + const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404); + const __m128i maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + const __m128i m32 = _mm_set1_epi8(-32); + const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101); + const __m128i e4 = _mm_set1_epi8(4); + const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); + +}; + template <typename Q8> inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) { const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0)); @@ -933,8 +999,14 @@ struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> { Scales8K s8k; }; +__m256i load_iq4nl_values() { + static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); + return MM256_SET_M128I(val128, val128); +} + struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> { - DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {} + DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values()) {} template <typename Q8> inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { d = GGML_FP16_TO_FP32(x[i].d); @@ -950,18 +1022,59 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> { bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); } - static __m256i load_values() { - static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; - auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); - return MM256_SET_M128I(val128, val128); - } - Q4Bits bits; Scales8K s8k; ScaleIQ4XS siq4; const __m256i values; }; +struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { + DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + auto scales8 = make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h); + auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff)); + auto extra128 = _mm_set1_epi16(x[i].extra); + extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask); + extra128 = _mm_and_si128(extra128, e4); + extra128 = _mm_shuffle_epi8(extra128, eshuffle); + auto scales_s = _mm256_mullo_epi16(scales16, _mm256_add_epi16(_mm256_set1_epi16(-128), _mm256_cvtepi8_epi16(extra128))); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i)); + accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + } + prepare_scales_16(scales16, scales); + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs, j); + bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); + } + __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { + uint64_t aux64; + memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); + auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); + auto sch = _mm_shuffle_epi8(aux, hshuff); + return _mm_add_epi8(_mm_or_si128(scl, sch), m32); + } + + Q4Bits bits; + const __m256i values; + const __m128i maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + const __m128i m32 = _mm_set1_epi8(-32); + const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101); + const __m128i e4 = _mm_set1_epi8(4); + const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); + +}; + struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> { DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template <typename Q8> @@ -2696,7 +2809,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { #else if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> || std::is_same_v<Dequantizer, DequantizerQ3K> || - std::is_same_v<Dequantizer, DequantizerQ6K>) { + std::is_same_v<Dequantizer, DequantizerQ6K> || + std::is_same_v<Dequantizer, DequantizerIQ4K>) { m.funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>; m.funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>; m.funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>; @@ -2783,6 +2897,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { assert (ne00 % QK_K == 0); MulMat::set_functions<DequantizerIQ4XS>(mm); break; + case GGML_TYPE_IQ4_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions<DequantizerIQ4K>(mm); + break; case GGML_TYPE_IQ3_S: assert (ne00 % QK_K == 0); MulMat::set_functions<DequantizerIQ3S>(mm); @@ -3345,6 +3463,78 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> { // ============================= i-quants +inline int32x4x4_t make_wider_8(const int8x16_t& scales8) { + int16x8x2_t scales16{vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8))}; + return make_wider(scales16); +} + +struct Scale16Extra { + template <typename Q8> + static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val, + const uint8_t * scales_l, const uint8_t * scales_h, const Q8& q8, float32x4_t * acc) { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + const uint32_t * aux32 = (const uint32_t *)scales_h; + uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2}; + uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30)); + int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, vreinterpretq_u8_u32(hshuff))); + scales8 = vaddq_s8(vqtbl1q_s8(scales8, vreinterpretq_u8_u32(hshuff)), vdupq_n_s8(-32)); + return new_block(i, d, extra, val, scales8, q8, acc); + } + inline static uint8x16_t get_extra(uint16_t extra) { + uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra)); + e8 = vceqq_u8(vandq_u8(e8, emask), emask); + return vqtbl1q_u8(e8, eshuff); + } + template <typename Q8> + static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val, + const int8x16_t& scales8, const Q8& q8, float32x4_t * acc) { + uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra)); + e8 = vceqq_u8(vandq_u8(e8, emask), emask); + e8 = vqtbl1q_u8(vandq_u8(e8, vdupq_n_u8(val)), eshuff); + int16x8x2_t extra16 = {vmull_s8(vget_low_s8 (e8), vget_low_s8 (scales8)), + vmull_s8(vget_high_s8(e8), vget_high_s8(scales8))}; + accum_mins_16(extra16, q8, acc, i, d); + return make_wider_8(scales8); + } + + constexpr static uint32x4_t hshuff = {0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}; + constexpr static uint32x4_t emask = {0x02020101, 0x08080404, 0x20201010, 0x80804040}; + constexpr static uint32x4_t eshuff = {0x06040200, 0x0e0c0a08, 0x07050301, 0x0f0d0b09}; +}; + +// Note: on ARM_NEON we cannot use the values shifted into the uint8_t range because +// the ARM_NEON only has vdotq_s32 or vdotq_u32, where both operands need to +// be signed or unsigned. As the Q8_K quants are signed, we need to have the +// iq4_s quants also signed. We can only use unsigned values in k-quants +// because they are all within the valid int8_t range. +struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { + DequantizerIQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8(iq4k_values)) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + inline void new_row(int ix) { x = (const block_iq4_k *)((const char *)vx + bx*ix); } + + template <typename Q8> + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 4, x[i].scales_l, x[i].scales_h, q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs+64*j); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]); + } + } + + Q4bits bits; + const int16x8_t values; + + float d; +}; + struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> { static int8x16_t load_values() { @@ -4671,6 +4861,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { case GGML_TYPE_IQ4_XS: MulMat::set_functions<DequantizerIQ4XS>(m); break; + case GGML_TYPE_IQ4_K: + MulMat::set_functions<DequantizerIQ4K>(m); + break; case GGML_TYPE_IQ2_XXS: MulMat::set_functions<DequantizerIQ2XXS>(m); break; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 8f541565..e60e61a1 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -11,6 +11,7 @@ #include "ggml-impl.h" #define GGML_COMMON_IMPL_C #include "ggml-common.h" +#include "iqk_quantize.h" #include <vector> #include <utility> @@ -412,3 +413,288 @@ void quantize_row_q8_K64(const float * x, void * y, int64_t k) { quantize_row_q8_K64_ref(x, (block_q8_K64 *)y, k); } +// +// ============================================== iq4_K +// +void dequantize_row_iq4_k(const block_iq4_k * x, float * y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const uint8_t * qs = x[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d); + + uint16_t extra = x[i].extra; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const uint8_t sh = x[i].scales_h[ib/2] >> 4*(ib%2); + const float dl1 = d * (((x[i].scales_l[ib] & 0xf) | ((sh << 4) & 0x30)) - 32); + const float dl2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32); + const int8_t * values1 = extra & 1 ? iq4k_values + 16 : iq4k_values; + const int8_t * values2 = extra & 2 ? iq4k_values + 16 : iq4k_values; + extra >>= 2; + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl1 * values1[qs[j] & 0xf]; + y[j+16] = dl2 * values2[qs[j] >> 4]; + } + y += 32; + qs += 16; + } + } +} + +void vec_dot_iq4_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + const int nb = n / QK_K; + + const block_iq4_k * x = (const block_iq4_k *)vx; + const block_q8_K * y = (const block_q8_K *)vy; + + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t extra = x[ibl].extra; + uint32_t h = *((const uint32_t *)x[ibl].scales_h); + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + int32_t sum = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls1 = (x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30) - 32; + const int ls2 = (x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30) - 32; + h >>= 4; + const int8_t * values1 = iq4k_values + 16*(extra & 1); + const int8_t * values2 = iq4k_values + 8*(extra & 2); + extra >>= 2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * values1[qs[j] & 0xf]; + sumi2 += q8[j+16] * values2[qs[j] >> 4]; + } + sum += ls1*sumi1 + ls2*sumi2; + qs += 16; + q8 += 32; + } + sumf += d4d8 * sum; + } + *s = sumf; + +} + +namespace { +const int8_t iq4nl_index[241] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14 +}; +inline int best_index_iq4nl(const int8_t * values, float x) { + if (x <= values[ 0]) return 0; + if (x >= values[15]) return 15; + int index = iq4nl_index[(int)x - values[0]]; + return x - values[index] < values[index+1] - x ? index : index + 1; +} + +static void quantize_row_iq4_k_impl_bs16(const int super_block_size, const int block_size, const float * x, + block_iq4_k * y, + float * scales, float * weight, uint8_t * L, + const int8_t * values, + const float * quant_weights, + const int ntry) { + + GGML_ASSERT(super_block_size == 256 && block_size == 16); + + float sigma2 = 0; + for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j]; + sigma2 *= 2.f/super_block_size; + + memset(y, 0, sizeof(block_iq4_k)); + y->d = GGML_FP32_TO_FP16(0.f); + + uint16_t * scales_h = (uint16_t *)y->scales_h; + + const int8_t * shifted_values = values + 16; + + float max_scale = 0, amax_scale = 0; + uint16_t extra = 0; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + const float * xb = x + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < block_size; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (!amax) { + scales[ib] = 0; + continue; + } + float d = ntry > 0 ? -max/values[0] : max/values[0]; + float id = 1/d; + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + d = sumqx_p/sumq2_p; + bool is_shifted = false; + float best = d*sumqx_p; + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; + } + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; + } + id = (itry + shifted_values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(shifted_values, al); + float q = shifted_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(shifted_values, -al); + q = shifted_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; + } + } + if (is_shifted) extra |= (1 << ib); + scales[ib] = d; + float abs_d = fabsf(d); + if (abs_d > amax_scale) { + amax_scale = abs_d; max_scale = d; + } + } + float d = -max_scale/32; + y->d = GGML_FP32_TO_FP16(d); + y->extra = extra; + float id = d ? 1/d : 0.f; + float sumqx = 0, sumq2 = 0; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + const int8_t * block_values = extra & (1 << ib) ? shifted_values : values; + int l = nearest_int(id*scales[ib]); + l = MAX(-32, MIN(31, l)); + float dl = d * l; + float idl = dl ? 1/dl : 0.f; + uint8_t * Lb = L + ib*block_size; + const float * xb = x + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } + for (int j = 0; j < block_size; ++j) { + Lb[j] = best_index_iq4nl(block_values, idl*xb[j]); + float w = weight[j]; + float q = block_values[Lb[j]]*l; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + l += 32; + uint8_t l_l = l & 0xf; + uint8_t l_h = l >> 4; + if (ib%2 == 0) y->scales_l[ib/2] = l_l; + else y->scales_l[ib/2] |= (l_l << 4); + scales_h[ib/8] |= (l_h << 2*(ib%8)); + } + if (sumq2 > 0) y->d = GGML_FP32_TO_FP16(sumqx/sumq2); + + for (int i = 0; i < super_block_size/32; ++i) { + for (int j = 0; j < 16; ++j) { + y->qs[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4); + } + } +} + +} + +void quantize_row_iq4_k_ref(const float * x, block_iq4_k * y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_k(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq4_k(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + block_iq4_k * y = (block_iq4_k *)vy; + quantize_row_iq4_k_ref(x, y, k); +} + +size_t quantize_iq4_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + uint8_t L[QK_K]; + float weight[16]; + float scales[QK_K/16]; + for (int64_t row = 0; row < nrows; ++row) { + block_iq4_k * iq4 = (block_iq4_k *)qrow; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * qw = imatrix ? imatrix + QK_K*ibl : NULL; + quantize_row_iq4_k_impl_bs16(QK_K, 16, src + QK_K*ibl, iq4 + ibl, + scales, weight, L, iq4k_values, qw, 7); + } + src += n_per_row; + qrow += nblock*sizeof(block_iq4_k); + } + return nrows * nblock * sizeof(block_iq4_k); +} diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h new file mode 100644 index 00000000..dcc12dd2 --- /dev/null +++ b/ggml/src/iqk/iqk_quantize.h @@ -0,0 +1,24 @@ +#pragma once + +#include <stdint.h> +#include <stddef.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" + +#ifdef __cplusplus +#define GGML_RESTRICT +extern "C" { +#else +#define GGML_RESTRICT restrict +#endif + +void quantize_row_iq4_k_ref(const float * GGML_RESTRICT x, block_iq4_k * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq4_k(const block_iq4_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq4_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +#ifdef __cplusplus +} +#endif diff --git a/include/llama.h b/include/llama.h index 246fdf32..a90b56e1 100644 --- a/include/llama.h +++ b/include/llama.h @@ -170,6 +170,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_BN = 36, LLAMA_FTYPE_MOSTLY_IQ2_BN = 37, + LLAMA_FTYPE_MOSTLY_IQ4_K = 38, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index eecfccbd..a87bfe59 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3761,6 +3761,7 @@ struct llama_model_loader { case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; + case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; @@ -4456,6 +4457,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_K: return "IQ4_K - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4"; @@ -15478,7 +15480,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; - else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) { + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K) && qs.model.hparams.n_gqa() >= 4) { new_type = GGML_TYPE_Q5_K; } else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && @@ -15495,6 +15497,13 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n // TODO: explore better strategies new_type = GGML_TYPE_Q8_0; } + else if (qs.model.hparams.n_gqa() >= 4) { + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_IQ3_XXS) new_type = GGML_TYPE_IQ3_S; + else if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_IQ3_S ) new_type = GGML_TYPE_Q4_K; + else if (new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_IQ4_XS || new_type == GGML_TYPE_IQ4_K) new_type = GGML_TYPE_Q5_K; + else if (new_type == GGML_TYPE_IQ4_NL) new_type = GGML_TYPE_Q5_K; + else if (new_type == GGML_TYPE_Q5_K) new_type = GGML_TYPE_Q6_K; + } ++qs.i_attention_wv; } else if (name.find("attn_k.weight") != std::string::npos) { if (qs.model.hparams.n_expert == 8) { @@ -15566,7 +15575,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || - ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) { + ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K) { new_type = GGML_TYPE_Q5_K; } } else { @@ -15620,7 +15629,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS || new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || - new_type == GGML_TYPE_IQ1_M) { + new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -15648,6 +15657,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_IQ4_K: case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; @@ -15751,6 +15761,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ4_K: default_type = GGML_TYPE_IQ4_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break; |