diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-10-02 15:22:13 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-02 15:22:13 +0300 |
commit | cce49832c1b81b4e535e78ff308417ef3a386b18 (patch) | |
tree | 33b10f9344f4656d58cd3ea068233ba75888498d /ggml/src/iqk/iqk_mul_mat.cpp | |
parent | d6909ed6f00f91f20c9ef628085a1a1a6a55c453 (diff) |
Adding Q6_0 (#77)
* Adding q6_0 - basics + AVX2/Zen4 working
* Adding q6_0: CUDA dequantize works, but not mmvq
* Adding q6_0: CUDA mmvq works
* Adding q6_0: CUDA cpy, so Q6_0 can be used for KV-cache
* Add q6_0 to CPU flash attention
Disappointing result: for Llama-3.2-1B, q6_0 K- and V-cache
gives about the same PPL as q8_0 K-cache and q4_0 V-cache,
while needing the exact same RAM.
I.e., what was the point?
* q6_0: slightly better kv-cache result
Better than q8_0+q4_0, but not as good as q8_0+iq4_nl
* q6_0: works on ARM_NEON
* q6_0: dequantize works on Metal, but not vector dot product
* q6_0: it now works on Metal
Outperforms q5_0 by a significant margin. E.g.
| model | size | params | backend | ngl | threads | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | ------------: | ---------------: |
| llama 8B Q6_0 | 6.08 GiB | 8.03 B | Metal | 100 | 4 | tg128 | 44.02 ± 0.08 |
| llama 8B Q5_0 | 5.21 GiB | 8.03 B | Metal | 100 | 4 | tg128 | 40.13 ± 0.12 |
| llama 8B Q6_0 | 6.08 GiB | 8.03 B | Metal | 100 | 4 | pp512 | 500.55 ± 0.32 |
| llama 8B Q5_0 | 5.21 GiB | 8.03 B | Metal | 100 | 4 | pp512 | 448.02 ± 0.27 |
* q6_0: can now be used for kv-cache on Metal
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 166 |
1 file changed, 156 insertions(+), 10 deletions(-)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index d16f01d9..0c1c1625 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3228,6 +3228,16 @@ struct Q5_1_Dequantizer { return _mm256_or_si256(b4.dequant(x->qs), vqh); } }; +struct Q6_1_Dequantizer { + Dequantizer4bit b4; + const __m256i mh = _mm256_set1_epi8(0x30); + inline __m256i dequant(const block_q6_0 * x) const { + uint64_t aux64; std::memcpy(&aux64, x->qh, 8); + auto h128 = _mm_set_epi64x(aux64, aux64 << 4); + auto h256 = MM256_SET_M128I(_mm_srli_epi16(h128, 2), h128); + return _mm256_or_si256(b4.dequant(x->qs), _mm256_and_si256(h256, mh)); + } +}; template <typename Q, typename Scales, typename Dequantizer> struct Q_Unpacker { @@ -3332,6 +3342,11 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_ using Sum4T = Sum4Type1; inline static int block_size() { return QK4_1; } }; +struct Q6_0_1_Unpacker final : public Q_Unpacker<block_q6_0, ScaleHelperQ_0_1<32>, Q6_1_Dequantizer> { + Q6_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ81; + inline static int block_size() { return QK5_0; } +}; // float matrices - we handle f16, bf16 (if native bf16 support is available) and f32, but only to f32 result @@ -3628,7 +3643,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { } else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker> || std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || std::is_same_v<Dequantizer, Q4_0_1_Unpacker> || - std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) { + std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker> || + std::is_same_v<Dequantizer, Q6_0_1_Unpacker>) { m.funcs[0] = mul_mat_qX_1_q8_1_T<Dequantizer, 1>; m.funcs[1] = mul_mat_qX_1_q8_1_T<Dequantizer, 2>; m.funcs[2] = mul_mat_qX_1_q8_1_T<Dequantizer, 3>; @@ 
-3893,8 +3909,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { break; case GGML_TYPE_Q4_0: assert (ne00 % QK4_0 == 0); - //MulMat::set_functions<Q4_0_Unpacker>(mm); - //expected_typeB = GGML_TYPE_Q8_0; MulMat::set_functions<Q4_0_1_Unpacker>(mm); expected_typeB = GGML_TYPE_Q8_1; break; @@ -3905,8 +3919,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { break; case GGML_TYPE_Q5_0: assert (ne00 % QK5_0 == 0); - //MulMat::set_functions<Q5_0_Unpacker>(mm); - //expected_typeB = GGML_TYPE_Q8_0; MulMat::set_functions<Q5_0_1_Unpacker>(mm); expected_typeB = GGML_TYPE_Q8_1; break; @@ -3915,10 +3927,13 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { MulMat::set_functions<Q5_1_Unpacker>(mm); expected_typeB = GGML_TYPE_Q8_1; break; + case GGML_TYPE_Q6_0: + assert (ne00 % QK6_0 == 0); + MulMat::set_functions<Q6_0_1_Unpacker>(mm); + expected_typeB = GGML_TYPE_Q8_1; + break; case GGML_TYPE_Q8_0: assert (ne00 % QK8_0 == 0); - //MulMat::set_functions<Q8_0_Unpacker>(mm); - //expected_typeB = GGML_TYPE_Q8_0; MulMat::set_functions<Q8_0_1_Unpacker>(mm); expected_typeB = GGML_TYPE_Q8_1; break; @@ -5417,6 +5432,34 @@ struct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> { //ggml_half aux[4]; }; +struct DequantizerQ60 final : public BaseLegacyDequantizer<block_q6_0> { + + DequantizerQ60(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + auto qh8 = vld1_u8(x[i].qh); + auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); + q[0] = vaddq_s8(vorrq_u8(q[0], vandq_u8(qh, hmask)), m32); + q[1] = vaddq_s8(vorrq_u8(q[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const 
float16_t *)aux); + } + + const int8x16_t m32 = vdupq_n_s8(-32); + const uint8x16_t hmask = vdupq_n_u8(0x30); +}; + struct DequantizerIQ4NL final : public BaseLegacyDequantizer<block_iq4_nl> { DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} @@ -6325,7 +6368,8 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> || - std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL>) { + std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL> || + std::is_same_v<Dequantizer, DequantizerQ60>) { m.funcs[0] = mul_mat_qX_0_q8_0<Dequantizer, 1>; m.funcs[1] = mul_mat_qX_0_q8_0<Dequantizer, 2>; m.funcs[2] = mul_mat_qX_0_q8_0<Dequantizer, 3>; @@ -6492,6 +6536,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { MulMat::set_functions<DequantizerQ51>(m); expected_Btype = GGML_TYPE_Q8_1; break; + case GGML_TYPE_Q6_0: + MulMat::set_functions<DequantizerQ60>(m); + expected_Btype = GGML_TYPE_Q8_0; + break; case GGML_TYPE_Q8_0: MulMat::set_functions<DequantizerQ80>(m); expected_Btype = GGML_TYPE_Q8_0; @@ -7227,6 +7275,64 @@ struct HelperIQ4nl final : public BaseHelper<step> { #endif }; +template <int D, int step> +struct HelperQ60 final : public BaseHelper<step> { +#ifdef __aarch64__ + using block_q8 = block_q8_0; +#else + using block_q8 = block_q8_1; +#endif + using Base = BaseHelper<step>; + HelperQ60(const char * data, int stride) : Base(data, stride) {} + + // Needed for v * softmax(k * q) + inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { + int j = F16::block_size*i; + auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0; +#ifdef __aarch64__ + // TODO + auto vd = F16::set1(*(const float16_t *)&dl->d); + auto qh8 = 
vld1_u8(dl->qh); + auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); + auto qs = vld1q_u8(dl->qs); + qs = j%QK4_0 ? vshrq_n_u8(qs, 4) : vandq_u8(qs, mask_l); + qs = vorrq_u8(qs, vandq_u8(mask_h, j%QK4_0 ? vshrq_n_u8(qh, 2) : qh)); + qs = vaddq_s8(qs, m32); + v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(qs)))); + v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(qs)))); +#else + auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); + auto bl = _mm_loadu_si128((const __m128i *)dl->qs); + uint64_t aux64; std::memcpy(&aux64, dl->qh, 8); + auto bh = _mm_set_epi64x(aux64, aux64 << 4); +#ifdef HAVE_FANCY_SIMD + auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32); + auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32); + v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); + v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh))); +#else + if (j%QK4_0) { + bl = _mm_srli_epi16(bl, 4); + bh = _mm_srli_epi16(bh, 2); + } + auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32)); + v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16)))); + v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1)))); +#endif +#endif + } + +#ifdef __AVX2__ + const __m128i mask_l = _mm_set1_epi8(0x0f); + const __m128i mask_h = _mm_set1_epi8(0x30); + const __m128i m32 = _mm_set1_epi8(-32); +#else + const uint8x16_t mask_l = vdupq_n_u8(0x0f); + const uint8x16_t mask_h = vdupq_n_u8(0x30); + const int8x16_t m32 = vdupq_n_s8(-32); +#endif +}; + template <int q_step, int k_step> struct FlashMS { // Something goes wrong when storing and manipulating K*Q as fp16. 
@@ -7759,6 +7865,14 @@ struct FlashQKfp32 { mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); #endif } + else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; +#ifdef __aarch64__ + mul_mat_qX_0_q8_0<DequantizerQ60, q_step>(D, kh.block, kh.stride, info, k_step); +#else + mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); +#endif + } else { GGML_ASSERT(false); } @@ -7880,6 +7994,28 @@ struct FlashQKfp32 { #endif } } + else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; + switch (nq) { +#ifdef __aarch64__ + case 1: mul_mat_qX_0_q8_0<DequantizerQ60, 1>(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0<DequantizerQ60, 2>(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0<DequantizerQ60, 3>(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0<DequantizerQ60, 4>(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0<DequantizerQ60, 5>(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0<DequantizerQ60, 6>(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0<DequantizerQ60, 7>(D, kh.block, kh.stride, info, k_step); break; +#else + case 1: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 6>(D, kh.block, 
kh.stride, info, k_step); break; + case 7: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; +#endif + } + } else { GGML_ASSERT(false); } @@ -8019,7 +8155,8 @@ struct FlashAttn { void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> || - std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) { + std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> || + std::is_same_v<KHelper, HelperQ60<D, k_step>>) { compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } else { @@ -8364,6 +8501,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperIQ4nl<D, k_step> vh(v, stride_v); iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q6_0: { + HelperQ60<D, k_step> vh(v, stride_v); + iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + } break; default: break; } } @@ -8395,6 +8536,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperIQ4nl<D, k_step> kh(k, stride_k); iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q6_0: { + HelperQ60<D, k_step> kh(k, stride_k); + iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + } break; default: break; } @@ -8404,7 +8549,8 @@ inline bool flash_attn_is_supported(ggml_type type) { #ifdef __AVX512BF16__ if (type == 
GGML_TYPE_BF16) return true; #endif - if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL) return true; + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || + type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true; return false; } } |