author    Kawrakow <48489457+ikawrakow@users.noreply.github.com>  2024-09-11 19:55:42 +0300
committer GitHub <noreply@github.com>                             2024-09-11 19:55:42 +0300
commit    c920195edd80ab24beb9a0fd3e2f4df582e735d0 (patch)
tree      7874c33edd5a407f72d1d0733093e97283377659
parent    d98a6753a63d970ebdc01c2b7b4f198644eef81c (diff)
AVX2 Flash Attention 2 (#50)
* AVX2 Flash Attention: add ability to use Q8_0 for kv-cache
* AVX2 Flash Attention: add ability to use Q4_0 for kv-cache
* AVX2 Flash Attention: add ability to use Q4_1 for kv-cache
* Fix Zen4

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--    ggml/src/iqk/iqk_mul_mat.cpp    156
1 file changed, 121 insertions(+), 35 deletions(-)
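
For context, the new HelperQ80/HelperQ40/HelperQ41 loaders in this patch dequantize the kv-cache on the fly inside the flash-attention kernel. Below is a minimal scalar sketch of the per-block math they vectorize with AVX2/AVX512. It assumes the standard ggml Q8_0/Q4_0/Q4_1 block layouts; the scales d and m are shown as float for readability (ggml stores them as fp16), and the struct and function names are illustrative, not the repo's types.

// Scalar reference for the dequantization the SIMD helpers perform.
#include <cstdint>

struct blk_q8_0 { float d;          int8_t  qs[32]; }; // 32 weights, one scale
struct blk_q4_0 { float d;          uint8_t qs[16]; }; // 32 4-bit weights, one scale
struct blk_q4_1 { float d; float m; uint8_t qs[16]; }; // 32 4-bit weights, scale + min

// Q8_0: x[i] = d * q[i]
void dequant_q8_0(const blk_q8_0 & b, float * y) {
    for (int i = 0; i < 32; ++i) y[i] = b.d * b.qs[i];
}

// Q4_0: low nibbles hold x[0..15], high nibbles hold x[16..31]; values are offset by -8
void dequant_q4_0(const blk_q4_0 & b, float * y) {
    for (int i = 0; i < 16; ++i) {
        y[i +  0] = b.d * ((b.qs[i] & 0xf) - 8);
        y[i + 16] = b.d * ((b.qs[i] >>  4) - 8);
    }
}

// Q4_1: like Q4_0, but unsigned nibbles with a per-block minimum m instead of the -8 offset
void dequant_q4_1(const blk_q4_1 & b, float * y) {
    for (int i = 0; i < 16; ++i) {
        y[i +  0] = b.d * (b.qs[i] & 0xf) + b.m;
        y[i + 16] = b.d * (b.qs[i] >>  4) + b.m;
    }
}
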
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 9267f0f3..5a8cbce2 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6548,22 +6548,23 @@ struct HelperF16 final : public BaseHelper<step> {
}
};
-#ifdef HAVE_FANCY_SIMD
+#if defined __AVX2__
template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
static_assert(step == QK8_0);
using Base = BaseHelper<step>;
- using F16 = HelperF16<D, step>;
+ //using F16 = HelperF16<D, step>;
HelperQ80(const char * data, int stride) : Base(data, stride) {}
- inline void load(int l1, __m512 * vk) const {
+ inline void load(int l1, F16::Data * vk) const {
auto dl = (const block_q8_0_x4 *)Base::lblock(l1);
if constexpr (D >= 128) {
- __m512 vd[4];
+ F16::Data vd[4];
for (int ib = 0; ib < D/128; ++ib) {
const auto& b8 = dl[ib];
auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)b8.d));
auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
+#ifdef HAVE_FANCY_SIMD
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
@@ -6573,29 +6574,57 @@ struct HelperQ80 final : public BaseHelper<step> {
vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+0))));
vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+1))));
}
+#else
+ vd[0] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(0, 0, 0, 0));
+ vd[1] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(1, 1, 1, 1));
+ vd[2] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(2, 2, 2, 2));
+ vd[3] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(3, 3, 3, 3));
+ for (int i = 0; i < 4; ++i) {
+ vk[16*ib+4*i+0] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 0)))));
+ vk[16*ib+4*i+1] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 8)))));
+ vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+16)))));
+ vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+24)))));
+ }
+#endif
}
} else {
for (int i = 0; i < D/32; ++i) {
const auto& b8 = dl[i/4];
int ii = i%4;
- auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(b8.d[ii]));
+ auto vd = F16::set1(GGML_FP16_TO_FP32(b8.d[ii]));
+#ifdef HAVE_FANCY_SIMD
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+1))));
+#else
+ vk[4*i+0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+ 0)))));
+ vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+ 8)))));
+ vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+16)))));
+ vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+24)))));
+#endif
}
}
}
- inline void load(int l1, int i, __m512& v1, __m512& v2) const {
- auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + i/8;
- int ii = (i/2)%4;
- auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d[ii]));
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ // Say D = 256 -> i is 0, 2, 4, 6, 8, ..., 28, 30. 128/8 = 16 -> we use 1st block of 128 for i = 0, 2, ..., 14, second for i = 16, 18, ..., 30
+ // i = 0, 2 -> ii = 0, i = 4, 6 -> ii = 1, i = 8, 10 -> ii = 2, i = 12, 14 -> ii = 3, i = 16, 18 -> ii = 0, etc.
+ // i*F16::block_size/128
+ int j = F16::block_size*i;
+ auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + j/(4*QK8_0);
+ int ii = (j/QK8_0)%4;
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d[ii]));
+#ifdef HAVE_FANCY_SIMD
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0))));
v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+1))));
+#else
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32)))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32+8)))));
+#endif
}
- inline void load_2(int l1, __m512 * vk) const {
+ inline void load_2(int l1, F16::Data * vk) const {
load(l1+0, vk+0);
- load(l1+1, vk+D/16);
+ load(l1+1, vk+D/F16::block_size);
}
};
@@ -6606,11 +6635,11 @@ struct HelperQ40 final : public BaseHelper<step> {
HelperQ40(const char * data, int stride) : Base(data, stride) {}
- inline void load(int l1, __m512 * vk) const {
+ inline void load(int l1, F16::Data * vk) const {
auto dl = (const block_q4_0 *)Base::lblock(l1);
if constexpr (D >= 128) {
ggml_half aux[4];
- __m512 vd[4];
+ F16::Data vd[4];
for (int ib = 0; ib < D/128; ++ib) {
for (int i = 0; i < 4; ++i) {
auto& b4 = dl[4*ib+i];
@@ -6618,11 +6647,21 @@ struct HelperQ40 final : public BaseHelper<step> {
auto q = _mm_loadu_si128((const __m128i *)b4.qs);
auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
+#ifdef HAVE_FANCY_SIMD
vk[8*ib+2*i+0] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql));
vk[8*ib+2*i+1] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh));
+#else
+ auto ql16 = _mm256_cvtepi8_epi16(ql);
+ auto qh16 = _mm256_cvtepi8_epi16(qh);
+ vk[16*ib+4*i+0] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16)));
+ vk[16*ib+4*i+1] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1)));
+ vk[16*ib+4*i+2] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16)));
+ vk[16*ib+4*i+3] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1)));
+#endif
}
auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)aux));
auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
+#ifdef HAVE_FANCY_SIMD
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
@@ -6632,32 +6671,61 @@ struct HelperQ40 final : public BaseHelper<step> {
vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+0]);
vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+1]);
}
+#else
+ vd[0] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(0, 0, 0, 0));
+ vd[1] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(1, 1, 1, 1));
+ vd[2] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(2, 2, 2, 2));
+ vd[3] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(3, 3, 3, 3));
+ for (int i = 0; i < 4; ++i) {
+ vk[16*ib+4*i+0] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+0]);
+ vk[16*ib+4*i+1] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+1]);
+ vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+2]);
+ vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+3]);
+ }
+#endif
}
} else {
for (int i = 0; i < D/32; ++i) {
- auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d));
auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
+#ifdef HAVE_FANCY_SIMD
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+ auto ql16 = _mm256_cvtepi8_epi16(ql);
+ auto qh16 = _mm256_cvtepi8_epi16(qh);
+ vk[4*i+0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))));
+ vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))));
+ vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))));
+ vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))));
+#endif
}
}
}
- inline void load(int l1, int i, __m512& v1, __m512& v2) const {
- auto dl = (const block_q4_0 *)Base::lblock(l1) + i/2;
- auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ int j = F16::block_size*i;
+ auto dl = (const block_q4_0 *)Base::lblock(l1) + j/QK4_0;
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
+#ifdef HAVE_FANCY_SIMD
auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+ if (j%QK4_0) q = _mm_srli_epi16(q, 4);
+ auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_and_si128(q, mask), m8));
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
+#endif
}
- inline void load_2(int l1, __m512 * vk) const {
+ inline void load_2(int l1, F16::Data * vk) const {
load(l1+0, vk+0);
- load(l1+1, vk+D/16);
+ load(l1+1, vk+D/F16::block_size);
}
const __m128i mask = _mm_set1_epi8(0xf);
@@ -6671,33 +6739,51 @@ struct HelperQ41 final : public BaseHelper<step> {
HelperQ41(const char * data, int stride) : Base(data, stride) {}
- inline void load(int l1, __m512 * vk) const {
+ inline void load(int l1, F16::Data * vk) const {
auto dl = (const block_q4_1 *)Base::lblock(l1);
for (int i = 0; i < D/32; ++i) {
- auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
- auto vm = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].m));
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d));
+ auto vm = F16::set1(GGML_FP16_TO_FP32(dl[i].m));
auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
auto ql = _mm_and_si128(q, mask);
auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
+#ifdef HAVE_FANCY_SIMD
vk[2*i+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
vk[2*i+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
+#else
+ auto ql16 = _mm256_cvtepi8_epi16(ql);
+ auto qh16 = _mm256_cvtepi8_epi16(qh);
+ vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))), vm);
+ vk[4*i+1] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))), vm);
+ vk[4*i+2] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))), vm);
+ vk[4*i+3] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))), vm);
+#endif
}
}
- inline void load(int l1, int i, __m512& v1, __m512& v2) const {
- auto dl = (const block_q4_1 *)Base::lblock(l1) + i/2;
- auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
- auto vm = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->m));
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ int j = F16::block_size*i;
+ auto dl = (const block_q4_1 *)Base::lblock(l1) + j/QK4_1;
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+ auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m));
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
+#ifdef HAVE_FANCY_SIMD
auto ql = _mm_and_si128(q, mask);
auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
v2 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
+#else
+ if (j%QK4_1) q = _mm_srli_epi16(q, 4);
+ auto q16 = _mm256_cvtepi8_epi16(_mm_and_si128(q, mask));
+ v1 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))), vm);
+ v2 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))), vm);
+#endif
}
- inline void load_2(int l1, __m512 * vk) const {
+ inline void load_2(int l1, F16::Data * vk) const {
load(l1+0, vk+0);
- load(l1+1, vk+D/16);
+ load(l1+1, vk+D/F16::block_size);
}
const __m128i mask = _mm_set1_epi8(0xf);
@@ -7518,7 +7604,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperF16<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX2__
case GGML_TYPE_Q8_0: {
HelperQ80<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -7547,7 +7633,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperF16<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX2__
case GGML_TYPE_Q8_0: {
HelperQ80<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -7567,15 +7653,15 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
}
inline bool flash_attn_is_supported(ggml_type type) {
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX2__
+ if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true;
#ifdef __AVX512BF16__
- return type == GGML_TYPE_F16 || type == GGML_TYPE_BF16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1;
-#else
- return type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1;
+ if (type == GGML_TYPE_BF16) return true;
#endif
#else
- return type == GGML_TYPE_F16;
+ if (type == GGML_TYPE_F16) return true;
#endif
+ return false;
}
}
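
The long comment inside HelperQ80::load(l1, i, v1, v2) above spells out how the flattened index i maps onto block_q8_0_x4 super-blocks (four Q8_0 blocks packed with their four scales). The following standalone check reproduces that mapping for the comment's D = 256 example, assuming QK8_0 = 32 and the AVX2 value F16::block_size = 8; the constants and output format are illustrative only.

// Standalone check of the block/scale indexing used by HelperQ80::load(l1, i, v1, v2).
// QK8_0 = 32 quants per Q8_0 block; block_q8_0_x4 packs 4 such blocks (128 quants).
// On AVX2 each __m256 holds 8 floats, so one call covers scalar indices j = 8*i .. 8*i+15.
#include <cstdio>

int main() {
    constexpr int QK8_0      = 32;  // quants per Q8_0 block
    constexpr int block_size = 8;   // floats per SIMD register on AVX2 (16 on AVX512)
    constexpr int D          = 256; // head size used in the comment's example

    for (int i = 0; i < D/block_size; i += 2) {   // i advances by 2: each call fills v1 and v2
        int j  = block_size*i;                    // first scalar index covered by this call
        int x4 = j/(4*QK8_0);                     // which block_q8_0_x4 super-block
        int ii = (j/QK8_0)%4;                     // which of its 4 scales/sub-blocks
        printf("i=%2d -> j=%3d  x4-block=%d  sub-block=%d  qs bytes %3d..%3d\n",
               i, j, x4, ii, 32*ii + j%32, 32*ii + j%32 + 15);
    }
}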