author | Kawrakow <iwankawrakow@gmail.com> | 2024-10-02 15:22:13 +0300
---|---|---
committer | GitHub <noreply@github.com> | 2024-10-02 15:22:13 +0300
commit | cce49832c1b81b4e535e78ff308417ef3a386b18 |
tree | 33b10f9344f4656d58cd3ea068233ba75888498d |
parent | d6909ed6f00f91f20c9ef628085a1a1a6a55c453 |
Adding Q6_0 (#77)
* Adding q6_0 - basics + AVX2/Zen4 working
* Adding q6_0: CUDA dequantize works, but not mmvq
* Adding q6_0: CUDA mmvq works
* Adding q6_0: CUDA cpy, so Q6_0 can be used for KV-cache
* Add q6_0 to CPU flash attention
Disappointing result: for LLaMA-3.2-1B, a q6_0 K- and V-cache
gives about the same PPL as a q8_0 K-cache with a q4_0 V-cache,
while needing exactly the same RAM (the size arithmetic is spelled out after the commit message).
I.e., what was the point?
* q6_0: slightly better kv-cache result
Better than q8_0+q4_0, but not as good as q8_0+iq4_nl
* q6_0: works on ARM_NEON
* q6_0: dequantize works on Metal, but not vector dot product
* q6_0: it now works on Metal
Outperforms q5_0 by a significant margin. E.g.
| model | size | params | backend | ngl | threads | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | ------------: | ---------------: |
| llama 8B Q6_0 | 6.08 GiB | 8.03 B | Metal | 100 | 4 | tg128 | 44.02 ± 0.08 |
| llama 8B Q5_0 | 5.21 GiB | 8.03 B | Metal | 100 | 4 | tg128 | 40.13 ± 0.12 |
| llama 8B Q6_0 | 6.08 GiB | 8.03 B | Metal | 100 | 4 | pp512 | 500.55 ± 0.32 |
| llama 8B Q5_0 | 5.21 GiB | 8.03 B | Metal | 100 | 4 | pp512 | 448.02 ± 0.27 |
* q6_0: can now be used for kv-cache on Metal
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
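
For reference, the block format added in ggml-common.h packs 32 weights as a ggml_half scale `d`, 16 bytes of low nibbles `qs`, and 8 bytes of packed high bits `qh`: 2 + 16 + 8 = 26 bytes per 32 weights, i.e. the 6.5 bpw advertised in the quantize.cpp entry. That is also the size arithmetic behind the RAM-parity remark above: q6_0 K- plus V-cache is 6.5 + 6.5 = 13 bpw, exactly the 8.5 + 4.5 bpw of a q8_0 K-cache plus q4_0 V-cache. The sketch below is a minimal, standalone C round-trip condensed from the reference routines in the patch (quantize_row_q6_0_ref / dequantize_row_q6_0 in ggml-quants.c); the fp16 scale is kept as a plain float, the struct and function names are illustrative, and the imatrix-weighted scale refinement used for model quantization is omitted.

```c
// Minimal sketch of the Q6_0 quantize/dequantize round-trip (illustrative names;
// condensed from quantize_row_q6_0_ref / dequantize_row_q6_0 in the diff below).
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define QK6_0 32

typedef struct {
    float   d;               // scale (ggml_half in the real block_q6_0)
    uint8_t qh[QK6_0/4];     // two high bits per 6-bit quant, packed 4 per byte
    uint8_t qs[QK6_0/2];     // low nibbles of weights j and j+16 share one byte
} blk_q6_0;

static void quantize_block(const float *x, blk_q6_0 *y) {
    float amax = 0.0f, max = 0.0f;                // value with the largest magnitude
    for (int j = 0; j < QK6_0; ++j) {
        if (fabsf(x[j]) > amax) { amax = fabsf(x[j]); max = x[j]; }
    }
    const float d  = max / -32;                   // quants live in [0, 63], offset by -32
    const float id = d ? 1.0f/d : 0.0f;
    y->d = d;
    memset(y->qh, 0, sizeof(y->qh));
    for (int j = 0; j < QK6_0/2; ++j) {
        int q0 = (int8_t)(x[j]          *id + 32.5f);   // round to nearest
        int q1 = (int8_t)(x[j + QK6_0/2]*id + 32.5f);
        if (q0 > 63) q0 = 63;
        if (q1 > 63) q1 = 63;
        y->qs[j] = (uint8_t)((q0 & 0x0F) | ((q1 & 0x0F) << 4));     // low 4 bits
        const uint8_t h = (uint8_t)((q0 >> 4) | ((q1 >> 4) << 2));  // high 2 bits each
        y->qh[j % (QK6_0/4)] |= (uint8_t)(h << 4*(j / (QK6_0/4)));
    }
}

static void dequantize_block(const blk_q6_0 *x, float *y) {
    const float d = x->d;
    for (int j = 0; j < QK6_0/2; ++j) {
        const uint8_t h  = x->qh[j % (QK6_0/4)] >> 4*(j / (QK6_0/4));
        const int32_t q0 = ((x->qs[j] & 0x0F) | ((h << 4) & 0x30)) - 32;
        const int32_t q1 = ((x->qs[j] >>   4) | ((h << 2) & 0x30)) - 32;
        y[j]           = q0*d;
        y[j + QK6_0/2] = q1*d;
    }
}

int main(void) {
    float x[QK6_0], y[QK6_0];
    for (int j = 0; j < QK6_0; ++j) x[j] = sinf(0.37f*(float)j);   // arbitrary test data
    blk_q6_0 b;
    quantize_block(x, &b);
    dequantize_block(&b, y);
    for (int j = 0; j < QK6_0; ++j) printf("%8.4f -> %8.4f\n", x[j], y[j]);
    return 0;
}
```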
-rw-r--r-- | common/common.cpp | 3
-rw-r--r-- | examples/llama-bench/llama-bench.cpp | 3
-rw-r--r-- | examples/quantize/quantize.cpp | 1
-rw-r--r-- | ggml/include/ggml.h | 2
-rw-r--r-- | ggml/src/ggml-common.h | 11
-rw-r--r-- | ggml/src/ggml-cuda.cu | 4
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 7
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 42
-rw-r--r-- | ggml/src/ggml-cuda/cpy.cu | 50
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 12
-rw-r--r-- | ggml/src/ggml-cuda/vecdotq.cuh | 44
-rw-r--r-- | ggml/src/ggml-metal.m | 33
-rw-r--r-- | ggml/src/ggml-metal.metal | 139
-rw-r--r-- | ggml/src/ggml-quants.c | 139
-rw-r--r-- | ggml/src/ggml-quants.h | 5
-rw-r--r-- | ggml/src/ggml.c | 26
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 166
-rw-r--r-- | include/llama.h | 1
-rw-r--r-- | src/llama.cpp | 3
19 files changed, 678 insertions, 13 deletions
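
One more reference sketch before the patch body: the new CUDA matrix-vector path (vec_dot_q6_0_q8_1 in ggml-cuda/vecdotq.cuh, below) runs the integer dot product on the unsigned 6-bit quants and folds the implicit -32 offset into a single correction term built from the per-block activation sum that block_q8_1 already carries. The following is a whole-block scalar model of that contraction; struct and function names are illustrative, scales are kept as plain floats, and the real kernel additionally splits the block across threads (hence the 32*vdr/QI6_0 factor in the source) and uses dp4a on packed int32 lanes.

```c
// Whole-block scalar model of the q6_0 x q8_1 dot product used by mmvq
// (illustrative names; see vec_dot_q6_0_q8_1_impl in the diff below).
#include <stdint.h>
#include <stdio.h>

#define QK6_0 32
#define QK8_1 32

typedef struct {               // simplified block_q6_0: float scale instead of ggml_half
    float   d;
    uint8_t qh[QK6_0/4];
    uint8_t qs[QK6_0/2];
} blk_q6_0;

typedef struct {               // simplified block_q8_1: scale d and s = d * sum(quants)
    float  d, s;
    int8_t qs[QK8_1];
} blk_q8_1;

static float vec_dot_q6_0_q8_1_scalar(const blk_q6_0 *x, const blk_q8_1 *y) {
    int sumi = 0;              // dot product on the *unsigned* 6-bit quants (0..63)
    for (int j = 0; j < QK6_0/2; ++j) {
        const uint8_t h  = x->qh[j % (QK6_0/4)] >> 4*(j / (QK6_0/4));
        const int     q0 = (x->qs[j] & 0x0F) | ((h << 4) & 0x30);   // weight j
        const int     q1 = (x->qs[j] >>   4) | ((h << 2) & 0x30);   // weight j + 16
        sumi += q0 * y->qs[j] + q1 * y->qs[j + QK6_0/2];
    }
    // sum((q6 - 32) * q8) * d6 * d8  ==  d6 * (d8 * sumi - 32 * s8),  s8 = d8 * sum(q8)
    return x->d * (y->d * sumi - 32.0f * y->s);
}

int main(void) {
    blk_q6_0 x = { .d = 0.05f };     // qh/qs zero-initialized
    blk_q8_1 y = { .d = 0.02f };
    float ref = 0.0f, q8sum = 0.0f;
    for (int j = 0; j < QK6_0/2; ++j) {           // arbitrary 6-bit quants: 2j and 2j+1
        x.qs[j]      = (uint8_t)(((2*j) & 0x0F) | (((2*j+1) & 0x0F) << 4));
        x.qh[j % 8] |= (uint8_t)((((2*j) >> 4) | (((2*j+1) >> 4) << 2)) << 4*(j/8));
    }
    for (int j = 0; j < QK8_1; ++j) { y.qs[j] = (int8_t)(j - 16); q8sum += y.qs[j]; }
    y.s = y.d * q8sum;
    for (int j = 0; j < QK6_0/2; ++j) {           // reference: dequantize, then dot
        ref += x.d*((2*j  ) - 32) * y.d*y.qs[j];
        ref += x.d*((2*j+1) - 32) * y.d*y.qs[j + QK6_0/2];
    }
    printf("model = %g, reference = %g\n", vec_dot_q6_0_q8_1_scalar(&x, &y), ref);
    return 0;
}
```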
diff --git a/common/common.cpp b/common/common.cpp index 6c298d2d..75dd78e6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2242,6 +2242,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { if (s == "q5_1") { return GGML_TYPE_Q5_1; } + if (s == "q6_0") { + return GGML_TYPE_Q6_0; + } throw std::runtime_error("Invalid cache type: " + s); } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index fc77be50..9e4fd266 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -327,6 +327,9 @@ static ggml_type ggml_type_from_name(const std::string & s) { if (s == "iq4_nl") { return GGML_TYPE_IQ4_NL; } + if (s == "q6_0") { + return GGML_TYPE_Q6_0; + } return GGML_TYPE_COUNT; } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index c11b8631..2b240299 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -20,6 +20,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 3.90G, +0.1585 ppl @ LLaMA-v1-7B", }, { "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", }, { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", }, + { "Q6_0", LLAMA_FTYPE_MOSTLY_Q6_0, " 6.5 bpw quantization", }, { "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS, " 2.06 bpw quantization", }, { "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", }, { "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S, " 2.5 bpw quantization", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 36cc531f..08fe6a3e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -392,6 +392,7 @@ extern "C" { GGML_TYPE_Q4_0_4_8 = 32, GGML_TYPE_Q4_0_8_8 = 33, // + GGML_TYPE_Q6_0 = 133, GGML_TYPE_IQ1_BN = 134, GGML_TYPE_IQ2_BN = 135, GGML_TYPE_Q8_K64 = 136, @@ -447,6 +448,7 @@ extern "C" { GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors // + GGML_FTYPE_MOSTLY_Q6_0 = 127, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_BN = 128, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_BN = 129, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_K = 130, // except 1d tensors diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index bb0c4864..02ecf071 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -95,6 +95,9 @@ typedef sycl::half2 ggml_half2; #define QI5_1 (QK5_1 / (4 * QR5_1)) #define QR5_1 2 +#define QI6_0 (QK6_0 / (4 * QR6_0)) +#define QR6_0 2 + #define QI8_0 (QK8_0 / (4 * QR8_0)) #define QR8_0 1 @@ -187,6 +190,14 @@ typedef struct { } block_q5_1; static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); +#define QK6_0 32 +typedef struct { + ggml_half d; // delta + uint8_t qh[QK6_0/4]; // 5+6-th bit of quants + uint8_t qs[QK6_0/2]; // nibbles / quants +} block_q6_0; +static_assert(sizeof(block_q6_0) == sizeof(ggml_half) + QK6_0/2 + QK6_0/4, "wrong q6_0 block size/padding"); + #define QK8_0 32 typedef struct { ggml_half d; // delta diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 617dd58f..64cc7592 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2807,6 +2807,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: @@ -2880,6 +2881,9 @@ GGML_CALL static bool 
ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) { return true; } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q6_0) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index d75b219b..d7e9c529 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -376,6 +376,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q6_0> { + static constexpr int qk = QK6_0; + static constexpr int qr = QR6_0; + static constexpr int qi = QI6_0; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> { static constexpr int qk = QK8_0; static constexpr int qr = QR8_0; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index c74b030b..7089a6df 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -129,6 +129,36 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t } } +template<typename dst_t> +static __global__ void dequantize_block_q6_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { + + const int64_t i = blockIdx.x; + + // assume 32 threads + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + dst_t * y = yy + 256*i + 32*ir + 4*il; + + const block_q6_0 * x = (const block_q6_0 *)vx + ib; + const float d = __half2float(x->d); + const float dm = -32*d; + + const uint8_t * qs = x->qs + 4*il; + const uint8_t * qh = x->qh + 4*(il%2); + + for (int l = 0; l < 4; ++l) { + const uint8_t h = qh[l] >> 4*(il/2); + y[l+ 0] = d * ((qs[l] & 0xF) | ((h << 4) & 0x30)) + dm; + y[l+16] = d * ((qs[l] >> 4) | ((h << 2) & 0x30)) + dm; + } +} + //================================== k-quants template<typename dst_t> @@ -768,6 +798,14 @@ static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t n } template<typename dst_t> +static void dequantize_row_q6_0_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb32 = k / 32; + const int nb = (k + 255) / 256; + dequantize_block_q6_0<<<nb, 32, 0, stream>>>(vx, y, nb32); +} + +template<typename dst_t> static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; const int nb32 = k / 32; @@ -1004,6 +1042,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>; case GGML_TYPE_Q5_1: return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>; + case GGML_TYPE_Q6_0: + return dequantize_row_q6_0_cuda; case GGML_TYPE_Q8_0: if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) { return dequantize_block_q8_0_f16_cuda; @@ -1074,6 +1114,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>; case GGML_TYPE_Q5_1: return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>; + case GGML_TYPE_Q6_0: + return dequantize_row_q6_0_cuda; case GGML_TYPE_Q8_0: return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>; case GGML_TYPE_Q2_K: diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 1a84a4cb..0b269a86 100644 --- 
a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -221,6 +221,41 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { memcpy(dsti->qh, &qh, sizeof(qh)); } +static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q6_0 * dsti = (block_q6_0 *) cdsti; + + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK6_0; ++j) { + const float v = xi[j]; + const float av = fabsf(xi[j]); + if (amax < av) { + amax = av; + vmax = v; + } + } + + const float d = vmax / -32; + const float id = d ? 1.0f/d : 0.0f; + + dsti->d = d; + memset(dsti->qh, 0, QK6_0/4); + + for (int j = 0; j < QK6_0/2; ++j) { + const float x0 = xi[0 + j]*id; + const float x1 = xi[QK4_0/2 + j]*id; + + const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f)); + const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f)); + + dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2); + dsti->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4))); + } +} + static __device__ const int8_t iq4nl_index[241] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, @@ -397,6 +432,17 @@ static void ggml_cpy_f32_q5_1_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void ggml_cpy_f32_q6_0_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + + GGML_ASSERT(ne % QK6_0 == 0); + const int num_blocks = ne / QK6_0; + cpy_f32_q<cpy_blck_f32_q6_0, QK6_0><<<num_blocks, 1, 0, stream>>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -466,6 +512,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) { + ggml_cpy_f32_q6_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { @@ -505,6 +553,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) { + 
return (void*) cpy_f32_q<cpy_blck_f32_q6_0, QK6_0>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { return (void*) cpy_f32_f16<cpy_1_f32_f16>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 5f932fef..15e8fb5a 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -9,6 +9,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 : type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 : type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 : + type == GGML_TYPE_Q6_0 ? vec_dot_q6_0_q8_1 : type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 : type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 : type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 : @@ -34,6 +35,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : + type == GGML_TYPE_Q6_0 ? VDR_Q6_0_Q8_1_MMVQ : type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ : type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : @@ -232,6 +234,13 @@ static void mul_mat_vec_q5_1_q8_1_cuda( mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +static void mul_mat_vec_q6_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + mul_mat_vec_q_cuda<GGML_TYPE_Q6_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + static void mul_mat_vec_q8_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { @@ -384,6 +393,9 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_Q5_1: mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_Q6_0: + mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_Q8_0: mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index b1b465a3..7baabb7a 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -48,6 +48,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); } +#define VDR_Q6_0_Q8_1_MMVQ 2 +#define VDR_Q6_0_Q8_1_MMQ 4 + +template <int vdr> static __device__ __forceinline__ float vec_dot_q6_0_q8_1_impl( + const int * vl, const int * vh, const int * u, const float & d6, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = ((vl[i] >> 0) & 0x0F0F0F0F) | ((vh[i] << 4) & 0x30303030); + const int vi1 = ((vl[i] >> 4) & 0x0F0F0F0F) | ((vh[i] << 2) & 0x30303030); + + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 8 from each quant value + return d6 * (sumi * ds8f.x - (32.f*vdr/QI6_0) * 
ds8f.y); +} + #define VDR_Q4_1_Q8_1_MMVQ 2 #define VDR_Q4_1_Q8_1_MMQ 4 @@ -549,6 +573,26 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds); } +static __device__ __forceinline__ float vec_dot_q6_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q6_0 * bq6_0 = (const block_q6_0 *) vbq + kbx; + + int vl[VDR_Q6_0_Q8_1_MMVQ]; + int vh[VDR_Q6_0_Q8_1_MMVQ]; + int u[2*VDR_Q6_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q6_0_Q8_1_MMVQ; ++i) { + vl[i] = get_int_b2(bq6_0->qs, iqs + i); + vh[i] = get_int_b2(bq6_0->qh, i) >> 4*(iqs/2); + u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI6_0); + } + + return vec_dot_q6_0_q8_1_impl<VDR_Q6_0_Q8_1_MMVQ>(vl, vh, u, bq6_0->d, bq8_1->ds); +} + static __device__ __forceinline__ float vec_dot_q4_1_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 774314df..dcdd0efe 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -81,6 +81,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0, GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, @@ -121,6 +122,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, @@ -155,6 +157,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, @@ -186,6 +189,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, @@ -217,6 +221,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, @@ -271,6 +276,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0, GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_SQR, @@ -603,6 +609,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0, get_rows_q6_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); @@ -643,6 +650,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_0_F32, mul_mv_q6_0_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); @@ -677,6 +685,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32, mul_mv_id_q6_0_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); @@ -708,6 +717,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32, mul_mm_q6_0_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); @@ -739,6 +749,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32, mul_mm_id_q6_0_f32, ctx->support_simdgroup_mm); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); @@ -793,6 +804,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0, cpy_f32_q6_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); @@ -960,6 +972,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: case GGML_TYPE_IQ4_NL: return true; default: @@ -1910,6 +1923,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32 ].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; @@ -2028,6 +2042,12 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 8; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; } break; + case GGML_TYPE_Q6_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_0_F32].pipeline; + } break; case GGML_TYPE_Q8_0: { nth0 = 8; @@ -2200,7 +2220,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| @@ -2293,6 +2313,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; 
break; case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; @@ -2398,6 +2419,12 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 8; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; } break; + case GGML_TYPE_Q6_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline; + } break; case GGML_TYPE_Q8_0: { nth0 = 8; @@ -2581,7 +2608,7 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t _ne1 = 1; const int tgz = dst_rows; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| @@ -2632,6 +2659,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; @@ -3293,6 +3321,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; default: GGML_ABORT("not implemented"); }; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c1e11047..225fa5f1 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1281,8 +1281,30 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + +// function for calculate inner product between half a q6_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q6 quants begin (0 or QK6_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q6_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + 
float2 acc = 0.f; + + device const uint16_t * qh = (device const uint16_t *)qb_curr->qh; + device const uint16_t * qs = (device const uint16_t *)qb_curr->qs + il/2; + + const int shift = 4*(il/8); + for (int i = 0; i < 8; i += 2) { + acc[0] += yl[i + 0] * ((qs[i/2] & 0x000F) | ((qh[i/2] << (4-shift)) & 0x0030)) + + yl[i + 1] * ((qs[i/2] & 0x0F00) | ((qh[i/2] << (4-shift)) & 0x3000)); + acc[1] += yl[i + 8] * ((qs[i/2] & 0x00F0) | ((qh[i/2] << (6-shift)) & 0x0300)) + + yl[i + 9] * ((qs[i/2] & 0xF000) | (((uint32_t)qh[i/2] << (6-shift)) & 0x30000)); } - return d * (acc[0] + acc[1]) + sumy * m; + return d * (sumy * -32.f + acc[0] + acc[1]); } // putting them in the kernel cause a significant performance penalty @@ -1464,6 +1486,31 @@ kernel void kernel_mul_mv_q5_1_f32( mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } +kernel void kernel_mul_mv_q6_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl<block_q6_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} #define NB_Q8_0 8 @@ -3480,6 +3527,77 @@ kernel void kernel_cpy_f32_q5_1( } } +kernel void kernel_cpy_f32_q6_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK6_0; + + device block_q6_0 * dst_data = (device block_q6_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK6_0; i00 < ne00; i00 += ntg.x*QK6_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK6_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -32; + const float id = d ? 
1.0f/d : 0.0f; + + device block_q6_0 & b6 = dst_data[i00/QK6_0]; + b6.d = d; + device uint16_t * aux16 = (device uint16_t *)b6.qh; + aux16[0] = aux16[1] = aux16[2] = aux16[3] = 0; + + for (int j = 0; j < QK6_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK6_0/2 + j]*id; + + const uint8_t xi0 = MIN(63, (int8_t)(x0 + 32.5f)); + const uint8_t xi1 = MIN(63, (int8_t)(x1 + 32.5f)); + + b6.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2); + b6.qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4))); + } + } +} + static inline int best_index_int8(int n, constant float * val, float x) { if (x <= val[0]) return 0; if (x >= val[n-1]) return n-1; @@ -6844,6 +6962,21 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg } template <typename type4x4> +void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg) { + const float d = xb->d; + const float m = -32.h * xb->d; + device const uint8_t * qh = xb->qh; + device const uint8_t * qs = qh + 8; + + for (int i = 0; i < 8; i++) { + reg[i/4][i%4] = d * (((qs[i] >> 4*il) & 0xf) | (((qh[i] >> 2*il) << 4) & 0x30)) + m; + } + for (int i = 0; i < 8; i++) { + reg[2+i/4][i%4] = d * (((qs[i+8] >> 4*il) & 0xf) | ((qh[i] >> 2*il) & 0x30)) + m; + } +} + +template <typename type4x4> void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { device const int8_t * qs = ((device const int8_t *)xb->qs); const half d = xb->d; @@ -7839,6 +7972,7 @@ template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>; template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>; +template [[host_name("kernel_get_rows_q6_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_0, 2, dequantize_q6_0>; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>; @@ -7880,6 +8014,7 @@ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_1, 2, dequantize_q4_1>>; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_0, 2, dequantize_q5_0>>; template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_1, 2, dequantize_q5_1>>; +template [[host_name("kernel_mul_mm_q6_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q6_0, 2, dequantize_q6_0>>; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q8_0, 2, dequantize_q8_0>>; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q2_K, QK_NL, dequantize_q2_K>>; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q3_K, QK_NL, dequantize_q3_K>>; @@ -7918,6 +8053,7 @@ 
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_1, 2, dequantize_q4_1>>; template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_0, 2, dequantize_q5_0>>; template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_1, 2, dequantize_q5_1>>; +template [[host_name("kernel_mul_mm_id_q6_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q6_0, 2, dequantize_q6_0>>; template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q8_0, 2, dequantize_q8_0>>; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q2_K, QK_NL, dequantize_q2_K>>; template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q3_K, QK_NL, dequantize_q3_K>>; @@ -8138,6 +8274,7 @@ template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; +template [[host_name("kernel_mul_mv_id_q6_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q6_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq2_tn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_tn_f32_impl>>; template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index bef2f73e..f5fff22e 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -848,6 +848,59 @@ void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q5_1_ref(x, y, k); } +void quantize_row_q6_0_ref(const float * restrict x, block_q6_0 * restrict y, int64_t k) { + static const int qk = QK6_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -32; + const float id = d ? 
1.0f/d : 0.0f; + + //y[i].d = GGML_FP32_TO_FP16(d); + memset(y[i].qh, 0, qk/4); + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + const float w0 = x0*x0; + const float w1 = x1*x1; + + const uint8_t xi0 = MIN(63, (int8_t)(x0 + 32.5f)); + const uint8_t xi1 = MIN(63, (int8_t)(x1 + 32.5f)); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2); + y[i].qh[j%(qk/4)] |= (h << 4*(j/(qk/4))); + + const float q0 = (float)xi0 - 32.f; + const float q1 = (float)xi1 - 32.f; + sumqx += w0*x[i*qk + j]*q0 + w1*x[i*qk + qk/2 + j]*q1; + sumq2 += w0*q0*q0 + w1*q1*q1; + } + y[i].d = sumq2 > 0 ? GGML_FP32_TO_FP16(sumqx/sumq2) : GGML_FP32_TO_FP16(d); + } +} + +void quantize_row_q6_0(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q6_0_ref(x, y, k); +} + // reference implementation for deterministic creation of model files void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); @@ -1691,6 +1744,28 @@ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int6 } } +void dequantize_row_q6_0(const block_q6_0 * restrict x, float * restrict y, int64_t k) { + static const int qk = QK6_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t h = x[i].qh[j%(qk/4)] >> 4*(j/(qk/4)); + + const int32_t x0 = ((x[i].qs[j] & 0x0F) | ((h << 4) & 0x30)) - 32; + const int32_t x1 = ((x[i].qs[j] >> 4) | ((h << 2) & 0x30)) - 32; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK8_0; @@ -3429,6 +3504,54 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } +static void quantize_row_q6_0_impl(const float * restrict x, block_q6_0 * restrict y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK6_0 == 32, "QK6_0 must be 32"); + + float weight[QK6_0]; + int8_t L[QK6_0]; + + float sigma2 = 0; + if (quant_weights) { + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + sigma2 = sum_x2/n_per_row; + } + + const int64_t nb = n_per_row/QK6_0; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK6_0 * ib; + if (quant_weights) { + const float * qw = quant_weights + QK6_0 * ib; + for (int j = 0; j < QK6_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < QK6_0; ++j) weight[j] = xb[j]*xb[j]; + } + float d = make_qx_quants(QK6_0, 32, xb, L, 1, weight); + y[ib].d = GGML_FP32_TO_FP16(d); + + memset(y[ib].qh, 0, QK6_0/4); + + for (int j = 0; j < 16; ++j) { + const uint8_t xi0 = L[j]; + const uint8_t xi1 = L[j+16]; + y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2); + y[ib].qh[j%8] |= (h << 4*(j/8)); + } + } +} + +size_t quantize_q6_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q6_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q6_0_impl(src, (block_q6_0*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + size_t 
quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); @@ -5383,6 +5506,21 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r *s = sumf; } +void ggml_vec_dot_q6_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT +#ifdef __AVX2__ + const enum ggml_type vec_dot_type = GGML_TYPE_Q8_1; +#else + const enum ggml_type vec_dot_type = GGML_TYPE_Q8_0; +#endif + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q6_0, vx, bx, vec_dot_type, vy, by, s, bs, 0, 1)) { + return; + } +#endif + // TODO + *s = 0; +} + void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { @@ -15020,6 +15158,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_Q6_0: break; case GGML_TYPE_IQ2_K: break; case GGML_TYPE_IQ3_K: break; case GGML_TYPE_IQ4_K: break; diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 775aa875..bad7e9d9 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -25,6 +25,7 @@ void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_REST void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_0_ref(const float * GGML_RESTRICT x, block_q6_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); @@ -48,6 +49,7 @@ void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -72,6 +74,7 @@ void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRI void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); //void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q6_0(const block_q6_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -98,6 +101,7 @@ 
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q6_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -140,6 +144,7 @@ size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q6_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); void iq2xs_init_impl(enum ggml_type type); void iq2xs_free_impl(enum ggml_type type); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ee83fc43..d31713df 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -799,6 +799,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q6_0] = { + .type_name = "q6_0", + .blck_size = QK6_0, + .type_size = sizeof(block_q6_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q6_0, + .from_float = quantize_row_q6_0, + .from_float_ref = (ggml_from_float_t) quantize_row_q6_0_ref, + .vec_dot = ggml_vec_dot_q6_0_q8_0, +#if GGML_USE_IQK_MULMAT && defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1, +#else + .vec_dot_type = GGML_TYPE_Q8_0, +#endif + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q8_0] = { .type_name = "q8_0", .blck_size = QK8_0, @@ -3788,6 +3805,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; + case GGML_FTYPE_MOSTLY_Q6_0: wtype = GGML_TYPE_Q6_0; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; @@ -10237,6 +10255,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: @@ -10623,6 +10642,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K: @@ -10760,6 +10780,7 
@@ static void ggml_compute_forward_acc( case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K: @@ -13858,6 +13879,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: @@ -14234,6 +14256,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K: @@ -14505,6 +14528,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K: @@ -15103,6 +15127,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K: @@ -21899,6 +21924,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q6_0: result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index d16f01d9..0c1c1625 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3228,6 +3228,16 @@ struct Q5_1_Dequantizer { return _mm256_or_si256(b4.dequant(x->qs), vqh); } }; +struct Q6_1_Dequantizer { + Dequantizer4bit b4; + const __m256i mh = _mm256_set1_epi8(0x30); + inline __m256i dequant(const block_q6_0 * x) const { + uint64_t aux64; std::memcpy(&aux64, x->qh, 8); + auto h128 = _mm_set_epi64x(aux64, aux64 << 4); + auto h256 = MM256_SET_M128I(_mm_srli_epi16(h128, 2), h128); + return _mm256_or_si256(b4.dequant(x->qs), _mm256_and_si256(h256, mh)); + } +}; template <typename Q, typename Scales, typename Dequantizer> struct Q_Unpacker { @@ -3332,6 +3342,11 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_ using Sum4T = Sum4Type1; inline static int block_size() { return QK4_1; } }; +struct Q6_0_1_Unpacker final : public Q_Unpacker<block_q6_0, ScaleHelperQ_0_1<32>, Q6_1_Dequantizer> { + Q6_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ81; + inline static int block_size() { return QK5_0; } +}; // float matrices - we handle f16, bf16 (if native bf16 support is available) and f32, but only to f32 result @@ -3628,7 +3643,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { } else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker> || std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || 
std::is_same_v<Dequantizer, Q4_0_1_Unpacker> || - std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) { + std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker> || + std::is_same_v<Dequantizer, Q6_0_1_Unpacker>) { m.funcs[0] = mul_mat_qX_1_q8_1_T<Dequantizer, 1>; m.funcs[1] = mul_mat_qX_1_q8_1_T<Dequantizer, 2>; m.funcs[2] = mul_mat_qX_1_q8_1_T<Dequantizer, 3>; @@ -3893,8 +3909,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { break; case GGML_TYPE_Q4_0: assert (ne00 % QK4_0 == 0); - //MulMat::set_functions<Q4_0_Unpacker>(mm); - //expected_typeB = GGML_TYPE_Q8_0; MulMat::set_functions<Q4_0_1_Unpacker>(mm); expected_typeB = GGML_TYPE_Q8_1; break; @@ -3905,8 +3919,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { break; case GGML_TYPE_Q5_0: assert (ne00 % QK5_0 == 0); - //MulMat::set_functions<Q5_0_Unpacker>(mm); - //expected_typeB = GGML_TYPE_Q8_0; MulMat::set_functions<Q5_0_1_Unpacker>(mm); expected_typeB = GGML_TYPE_Q8_1; break; @@ -3915,10 +3927,13 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { MulMat::set_functions<Q5_1_Unpacker>(mm); expected_typeB = GGML_TYPE_Q8_1; break; + case GGML_TYPE_Q6_0: + assert (ne00 % QK6_0 == 0); + MulMat::set_functions<Q6_0_1_Unpacker>(mm); + expected_typeB = GGML_TYPE_Q8_1; + break; case GGML_TYPE_Q8_0: assert (ne00 % QK8_0 == 0); - //MulMat::set_functions<Q8_0_Unpacker>(mm); - //expected_typeB = GGML_TYPE_Q8_0; MulMat::set_functions<Q8_0_1_Unpacker>(mm); expected_typeB = GGML_TYPE_Q8_1; break; @@ -5417,6 +5432,34 @@ struct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> { //ggml_half aux[4]; }; +struct DequantizerQ60 final : public BaseLegacyDequantizer<block_q6_0> { + + DequantizerQ60(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + auto qh8 = vld1_u8(x[i].qh); + auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); + q[0] = vaddq_s8(vorrq_u8(q[0], vandq_u8(qh, hmask)), m32); + q[1] = vaddq_s8(vorrq_u8(q[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + + const int8x16_t m32 = vdupq_n_s8(-32); + const uint8x16_t hmask = vdupq_n_u8(0x30); +}; + struct DequantizerIQ4NL final : public BaseLegacyDequantizer<block_iq4_nl> { DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} @@ -6325,7 +6368,8 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> || - std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL>) { + std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL> || + std::is_same_v<Dequantizer, DequantizerQ60>) { m.funcs[0] = mul_mat_qX_0_q8_0<Dequantizer, 1>; m.funcs[1] = mul_mat_qX_0_q8_0<Dequantizer, 2>; m.funcs[2] = mul_mat_qX_0_q8_0<Dequantizer, 3>; @@ -6492,6 +6536,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { MulMat::set_functions<DequantizerQ51>(m); expected_Btype = 
@@ -6492,6 +6536,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
            MulMat::set_functions<DequantizerQ51>(m);
            expected_Btype = GGML_TYPE_Q8_1;
            break;
+       case GGML_TYPE_Q6_0:
+           MulMat::set_functions<DequantizerQ60>(m);
+           expected_Btype = GGML_TYPE_Q8_0;
+           break;
        case GGML_TYPE_Q8_0:
            MulMat::set_functions<DequantizerQ80>(m);
            expected_Btype = GGML_TYPE_Q8_0;
@@ -7227,6 +7275,64 @@ struct HelperIQ4nl final : public BaseHelper<step> {
#endif
};
+template <int D, int step>
+struct HelperQ60 final : public BaseHelper<step> {
+#ifdef __aarch64__
+    using block_q8 = block_q8_0;
+#else
+    using block_q8 = block_q8_1;
+#endif
+    using Base = BaseHelper<step>;
+    HelperQ60(const char * data, int stride) : Base(data, stride) {}
+
+    // Needed for v * softmax(k * q)
+    inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+        int j = F16::block_size*i;
+        auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0;
+#ifdef __aarch64__
+        // TODO
+        auto vd = F16::set1(*(const float16_t *)&dl->d);
+        auto qh8 = vld1_u8(dl->qh);
+        auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+        auto qs = vld1q_u8(dl->qs);
+        qs = j%QK4_0 ? vshrq_n_u8(qs, 4) : vandq_u8(qs, mask_l);
+        qs = vorrq_u8(qs, vandq_u8(mask_h, j%QK4_0 ? vshrq_n_u8(qh, 2) : qh));
+        qs = vaddq_s8(qs, m32);
+        v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(qs))));
+        v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(qs))));
+#else
+        auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+        auto bl = _mm_loadu_si128((const __m128i *)dl->qs);
+        uint64_t aux64; std::memcpy(&aux64, dl->qh, 8);
+        auto bh = _mm_set_epi64x(aux64, aux64 << 4);
+#ifdef HAVE_FANCY_SIMD
+        auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32);
+        auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32);
+        v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
+        v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+        if (j%QK4_0) {
+            bl = _mm_srli_epi16(bl, 4);
+            bh = _mm_srli_epi16(bh, 2);
+        }
+        auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32));
+        v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
+        v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
+#endif
+#endif
+    }
+
+#ifdef __AVX2__
+    const __m128i mask_l = _mm_set1_epi8(0x0f);
+    const __m128i mask_h = _mm_set1_epi8(0x30);
+    const __m128i m32 = _mm_set1_epi8(-32);
+#else
+    const uint8x16_t mask_l = vdupq_n_u8(0x0f);
+    const uint8x16_t mask_h = vdupq_n_u8(0x30);
+    const int8x16_t m32 = vdupq_n_s8(-32);
+#endif
+};
+
template <int q_step, int k_step>
struct FlashMS {
// Something goes wrong when storing and manipulating K*Q as fp16.
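The helper above only covers the read side of a q6_0 KV cache during flash attention. The write side goes through quantize_q6_0 and the copy kernels, whose implementations are not part of this excerpt; as a rough reference it presumably mirrors the other legacy _0 types (scale taken from the max-magnitude value, quants clamped to [0, 63]). A minimal sketch under that assumption, with a hypothetical function name and <math.h>/<string.h> plus the ggml block definitions assumed:

    // Sketch of quantizing one block of 32 floats to q6_0 (an assumption
    // mirroring the q4_0/q5_0 scheme; the actual quantize_row_q6_0 in
    // ggml-quants.c may differ in rounding details).
    static void quantize_block_q6_0_sketch(const float * x, block_q6_0 * y) {
        float amax = 0.0f, maxv = 0.0f;
        for (int j = 0; j < QK6_0; ++j) {
            if (fabsf(x[j]) > amax) { amax = fabsf(x[j]); maxv = x[j]; }
        }
        const float d  = maxv / -32.0f;              // signed range is [-32, 31]
        const float id = d ? 1.0f/d : 0.0f;
        y->d = GGML_FP32_TO_FP16(d);
        memset(y->qh, 0, QK6_0/4);
        for (int j = 0; j < QK6_0/2; ++j) {
            const int q0 = (int) fminf(63.f, fmaxf(0.f, roundf(x[j]*id) + 32));
            const int q1 = (int) fminf(63.f, fmaxf(0.f, roundf(x[j + QK6_0/2]*id) + 32));
            y->qs[j]    = (q0 & 0x0F) | ((q1 & 0x0F) << 4);           // low nibbles
            y->qh[j%8] |= ((q0 >> 4) & 0x03) << (j < 8 ? 0 : 4);      // high bits of value j
            y->qh[j%8] |= ((q1 >> 4) & 0x03) << (j < 8 ? 2 : 6);      // high bits of value j+16
        }
    }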
@@ -7759,6 +7865,14 @@ struct FlashQKfp32 {
            mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
#endif
        }
+       else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
+           DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
+#ifdef __aarch64__
+           mul_mat_qX_0_q8_0<DequantizerQ60, q_step>(D, kh.block, kh.stride, info, k_step);
+#else
+           mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+#endif
+       }
        else {
            GGML_ASSERT(false);
        }
@@ -7880,6 +7994,28 @@ struct FlashQKfp32 {
#endif
            }
        }
+       else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
+           DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
+           switch (nq) {
+#ifdef __aarch64__
+               case 1: mul_mat_qX_0_q8_0<DequantizerQ60, 1>(D, kh.block, kh.stride, info, k_step); break;
+               case 2: mul_mat_qX_0_q8_0<DequantizerQ60, 2>(D, kh.block, kh.stride, info, k_step); break;
+               case 3: mul_mat_qX_0_q8_0<DequantizerQ60, 3>(D, kh.block, kh.stride, info, k_step); break;
+               case 4: mul_mat_qX_0_q8_0<DequantizerQ60, 4>(D, kh.block, kh.stride, info, k_step); break;
+               case 5: mul_mat_qX_0_q8_0<DequantizerQ60, 5>(D, kh.block, kh.stride, info, k_step); break;
+               case 6: mul_mat_qX_0_q8_0<DequantizerQ60, 6>(D, kh.block, kh.stride, info, k_step); break;
+               case 7: mul_mat_qX_0_q8_0<DequantizerQ60, 7>(D, kh.block, kh.stride, info, k_step); break;
+#else
+               case 1: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
+               case 2: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
+               case 3: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
+               case 4: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
+               case 5: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
+               case 6: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
+               case 7: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
+#endif
+           }
+       }
        else {
            GGML_ASSERT(false);
        }
@@ -8019,7 +8155,8 @@ struct FlashAttn {
    void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
            const float * q, const char * mask, float * qkv) {
        if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
-                     std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
+                     std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
+                     std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
            compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
                    kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
        } else {
@@ -8364,6 +8501,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
            HelperIQ4nl<D, k_step> vh(v, stride_v);
            iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
        } break;
+       case GGML_TYPE_Q6_0: {
+           HelperQ60<D, k_step> vh(v, stride_v);
+           iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+       } break;
        default: break;
    }
}
@@ -8395,6 +8536,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
            HelperIQ4nl<D, k_step> kh(k, stride_k);
            iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1,
                nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
        } break;
+       case GGML_TYPE_Q6_0: {
+           HelperQ60<D, k_step> kh(k, stride_k);
+           iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+       } break;
        default: break;
    }
@@ -8404,7 +8549,8 @@ inline bool flash_attn_is_supported(ggml_type type) {
#ifdef __AVX512BF16__
    if (type == GGML_TYPE_BF16) return true;
#endif
-   if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL) return true;
+   if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
+       type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
    return false;
}
}
diff --git a/include/llama.h b/include/llama.h
index 02d94b6c..43c0091e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -167,6 +167,7 @@ extern "C" {
        LLAMA_FTYPE_MOSTLY_Q4_0_4_8      = 34, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_Q4_0_8_8      = 35, // except 1d tensors
        //
+       LLAMA_FTYPE_MOSTLY_Q6_0          = 135, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ1_BN        = 136, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ2_BN        = 137, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ2_K         = 138, // except 1d tensors
diff --git a/src/llama.cpp b/src/llama.cpp
index dca03ade..eb982125 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3774,6 +3774,7 @@ struct llama_model_loader {
            case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break;
            case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break;
            case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break;
+           case GGML_TYPE_Q6_0: ftype = LLAMA_FTYPE_MOSTLY_Q6_0; break;
            case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break;
            case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break;
            case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break;
@@ -4471,6 +4472,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
        case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
        case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
        case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
+       case LLAMA_FTYPE_MOSTLY_Q6_0: return "Q6_0";
        case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
        case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
        case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
@@ -15967,6 +15969,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
        case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
        case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
        case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
+       case LLAMA_FTYPE_MOSTLY_Q6_0: default_type = GGML_TYPE_Q6_0; break;
        case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
        case LLAMA_FTYPE_MOSTLY_F16:  default_type = GGML_TYPE_F16;  break;
        case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
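Taken together, these changes make Q6_0 selectable both as a model file type and as a KV-cache type. As a minimal sketch of picking it up programmatically (an illustration only: the llama_context_params fields type_k/type_v/flash_attn and llama_new_context_with_model are the pre-existing public API and are assumed here, not introduced by this patch):

    #include "llama.h"

    // Create a context whose K- and V-cache are stored as q6_0.
    struct llama_context * make_ctx_with_q6_0_cache(struct llama_model * model) {
        struct llama_context_params cparams = llama_context_default_params();
        cparams.type_k     = GGML_TYPE_Q6_0;   // quantized K-cache
        cparams.type_v     = GGML_TYPE_Q6_0;   // quantized V-cache
        cparams.flash_attn = true;             // a quantized V-cache generally requires flash attention
        return llama_new_context_with_model(model, cparams);
    }

The same selection is available from the command line through the usual cache-type and quantization options; the exact flag spellings are not part of this diff.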