summaryrefslogtreecommitdiff
path: root/ggml/src
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-10-01 14:46:40 +0300
committerGitHub <noreply@github.com>2024-10-01 14:46:40 +0300
commite7f5a86a41ecd469e31b9d003acf91eff43b946f (patch)
tree63aed20c449890a2c4a9896da76080cdb320ad5f /ggml/src
parent8457a26f83b2f6acd014449e91bfb60a37fcec0e (diff)
IQ4_NL kv-cache on the CPU (Zen4/AVX2/ARM_NEON) (#74)
* Be able to use IQ4_NL for KV cache on AVX2/Zen4 * Be able to use IQ4_NL for KV cache on ARM_NEON --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp98
1 files changed, 93 insertions, 5 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 1183246b..33b2a790 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -542,9 +542,13 @@ struct SimpleBits {
__m256i values[4];
};
-__m256i inline load_iq4nl_values_256() {
+__m128i inline load_iq4nl_values_128() {
static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
- auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
+ return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
+}
+
+__m256i inline load_iq4nl_values_256() {
+ auto val128 = load_iq4nl_values_128();
return MM256_SET_M128I(val128, val128);
}
@@ -6247,7 +6251,6 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
int32x4_t accd[nrc_y];
- const auto m1 = vdupq_n_u8(1);
const auto mask2 = vdupq_n_s8(3);
for (int ix = 0; ix < nrc_x; ++ix) {
@@ -7176,6 +7179,53 @@ struct HelperQ41 final : public BaseHelper<step> {
#endif
};
+template <int D, int step>
+struct HelperIQ4nl final : public BaseHelper<step> {
+ using Base = BaseHelper<step>;
+#ifdef __aarch64__
+ using block_q8 = block_q8_0;
+#else
+ using block_q8 = block_q8_1;
+#endif
+ HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {}
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ int j = F16::block_size*i;
+ auto dl = (const block_iq4_nl *)Base::lblock(l1) + j/QK4_0;
+#ifdef __aarch64__
+ auto vd = F16::set1(*(const float16_t *)&dl->d);
+ auto q = vld1q_u8(dl->qs);
+ q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
+ q = vqtbl1q_s8(values, q);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
+#else
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+ auto q = _mm_loadu_si128((const __m128i *)dl->qs);
+#ifdef HAVE_FANCY_SIMD
+ auto ql = _mm_shuffle_epi8(values, _mm_and_si128(q, mask));
+ auto qh = _mm_shuffle_epi8(values, _mm_and_si128(_mm_srli_epi16(q, 4), mask));
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+ if (j%QK4_0) q = _mm_srli_epi16(q, 4);
+ auto q16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(values, _mm_and_si128(q, mask)));
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
+#endif
+#endif
+ }
+
+#ifdef __aarch64__
+ const uint8x16_t mask = vdupq_n_u8(0xf);
+ const int8x16_t values;
+#else
+ const __m128i mask = _mm_set1_epi8(0xf);
+ const __m128i values = _mm_loadu_si128((const __m128i *)iq4k_values);
+#endif
+};
+
template <int q_step, int k_step>
struct FlashMS {
// Something goes wrong when storing and manipulating K*Q as fp16.
@@ -7700,6 +7750,14 @@ struct FlashQKfp32 {
mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
#endif
}
+ else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
+#ifdef __aarch64__
+ mul_mat_qX_0_q8_0<DequantizerIQ4NL, q_step>(D, kh.block, kh.stride, info, k_step);
+#else
+ mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+#endif
+ }
else {
GGML_ASSERT(false);
}
@@ -7799,6 +7857,28 @@ struct FlashQKfp32 {
#endif
}
}
+ else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
+ switch (nq) {
+#ifdef __aarch64__
+ case 1: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 7>(D, kh.block, kh.stride, info, k_step); break;
+#else
+ case 1: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
+#endif
+ }
+ }
else {
GGML_ASSERT(false);
}
@@ -7938,7 +8018,7 @@ struct FlashAttn {
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
- std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
+ std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else {
@@ -8279,6 +8359,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ41<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_IQ4_NL: {
+ HelperIQ4nl<D, k_step> vh(v, stride_v);
+ iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
default: break;
}
}
@@ -8306,16 +8390,20 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ41<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_IQ4_NL: {
+ HelperIQ4nl<D, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ } break;
default: break;
}
}
inline bool flash_attn_is_supported(ggml_type type) {
- if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true;
#ifdef __AVX512BF16__
if (type == GGML_TYPE_BF16) return true;
#endif
+ if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL) return true;
return false;
}
}