diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-10-01 14:46:40 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-01 14:46:40 +0300 |
commit | e7f5a86a41ecd469e31b9d003acf91eff43b946f (patch) | |
tree | 63aed20c449890a2c4a9896da76080cdb320ad5f /ggml/src | |
parent | 8457a26f83b2f6acd014449e91bfb60a37fcec0e (diff) |
IQ4_NL kv-cache on the CPU (Zen4/AVX2/ARM_NEON) (#74)
* Be able to use IQ4_NL for KV cache on AVX2/Zen4
* Be able to use IQ4_NL for KV cache on ARM_NEON
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 98 |
1 file changed, 93 insertions, 5 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 1183246b..33b2a790 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -542,9 +542,13 @@ struct SimpleBits { __m256i values[4]; }; -__m256i inline load_iq4nl_values_256() { +__m128i inline load_iq4nl_values_128() { static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; - auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); + return _mm_loadu_si128((const __m128i *)kvalues_iq4nl); +} + +__m256i inline load_iq4nl_values_256() { + auto val128 = load_iq4nl_values_128(); return MM256_SET_M128I(val128, val128); } @@ -6247,7 +6251,6 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn int32x4_t accd[nrc_y]; - const auto m1 = vdupq_n_u8(1); const auto mask2 = vdupq_n_s8(3); for (int ix = 0; ix < nrc_x; ++ix) { @@ -7176,6 +7179,53 @@ struct HelperQ41 final : public BaseHelper<step> { #endif }; +template <int D, int step> +struct HelperIQ4nl final : public BaseHelper<step> { + using Base = BaseHelper<step>; +#ifdef __aarch64__ + using block_q8 = block_q8_0; +#else + using block_q8 = block_q8_1; +#endif + HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {} + + // Needed for v * softmax(k * q) + inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { + int j = F16::block_size*i; + auto dl = (const block_iq4_nl *)Base::lblock(l1) + j/QK4_0; +#ifdef __aarch64__ + auto vd = F16::set1(*(const float16_t *)&dl->d); + auto q = vld1q_u8(dl->qs); + q = j%QK4_0 ? 
vshrq_n_u8(q, 4) : vandq_u8(q, mask); + q = vqtbl1q_s8(values, q); + v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q)))); + v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q)))); +#else + auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); + auto q = _mm_loadu_si128((const __m128i *)dl->qs); +#ifdef HAVE_FANCY_SIMD + auto ql = _mm_shuffle_epi8(values, _mm_and_si128(q, mask)); + auto qh = _mm_shuffle_epi8(values, _mm_and_si128(_mm_srli_epi16(q, 4), mask)); + v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); + v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh))); +#else + if (j%QK4_0) q = _mm_srli_epi16(q, 4); + auto q16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(values, _mm_and_si128(q, mask))); + v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16)))); + v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1)))); +#endif +#endif + } + +#ifdef __aarch64__ + const uint8x16_t mask = vdupq_n_u8(0xf); + const int8x16_t values; +#else + const __m128i mask = _mm_set1_epi8(0xf); + const __m128i values = _mm_loadu_si128((const __m128i *)iq4k_values); +#endif +}; + template <int q_step, int k_step> struct FlashMS { // Something goes wrong when storing and manipulating K*Q as fp16. 
@@ -7700,6 +7750,14 @@ struct FlashQKfp32 { mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); #endif } + else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; +#ifdef __aarch64__ + mul_mat_qX_0_q8_0<DequantizerIQ4NL, q_step>(D, kh.block, kh.stride, info, k_step); +#else + mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); +#endif + } else { GGML_ASSERT(false); } @@ -7799,6 +7857,28 @@ struct FlashQKfp32 { #endif } } + else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; + switch (nq) { +#ifdef __aarch64__ + case 1: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 1>(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 2>(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 3>(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 4>(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 5>(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 6>(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 7>(D, kh.block, kh.stride, info, k_step); break; +#else + case 1: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 
6>(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; +#endif + } + } else { GGML_ASSERT(false); } @@ -7938,7 +8018,7 @@ struct FlashAttn { void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> || - std::is_same_v<KHelper, HelperQ80<D, k_step>>) { + std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) { compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } else { @@ -8279,6 +8359,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ41<D, k_step> vh(v, stride_v); iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; + case GGML_TYPE_IQ4_NL: { + HelperIQ4nl<D, k_step> vh(v, stride_v); + iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + } break; default: break; } } @@ -8306,16 +8390,20 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ41<D, k_step> kh(k, stride_k); iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; + case GGML_TYPE_IQ4_NL: { + HelperIQ4nl<D, k_step> kh(k, stride_k); + iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + } break; default: break; } } inline bool flash_attn_is_supported(ggml_type type) { - if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true; #ifdef __AVX512BF16__ if 
(type == GGML_TYPE_BF16) return true; #endif + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL) return true; return false; } } |