summaryrefslogtreecommitdiff
path: root/ggml/src
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-09-10 19:17:04 +0300
committerGitHub <noreply@github.com>2024-09-10 19:17:04 +0300
commit72f5dfe12ac2263e47df53daa0f39acd1e2e4fb6 (patch)
treec12a902cb72f5120a6960fde25a26b83fe0c6b91 /ggml/src
parentd17d0c44267bd7d8040626d1006c8377dad4f502 (diff)
AVX2 Flash Attention (#48)
* First version of AVX2 Flash attention I simply took the Zen4 implementation and converted platform specific stuff to methods of a struct providing data loading/storing, conversions, multiply, add, etc. Most likely not optimal as the Zen4 strategy has been designed based on having 32 512-bit registers, so basically we can have 4X more data stored in vector registers compared to AVX2 with 16 x 256-bit. It still gives a small speedup (~4% at 2048 tokens) for Gemma-2b. * Fix Zenn4 parts broken via the AVX2 change * Try smaller q_step - no improvement * Fix ARM_NEON I had forgotten to guard the AVX2/Zen4 implementation against __aarch64__ --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp276
1 files changed, 165 insertions, 111 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 424cba85..3a3b9eba 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -252,6 +252,12 @@ inline int hsum_i32_8(const __m256i a) {
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
+inline float hmax_float_8(__m256 x) {
+ __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
+ max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
+ return _mm_cvtss_f32(max4);
+}
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
@@ -6395,7 +6401,7 @@ inline __m256 v_tanh(__m256 x) {
#endif
} // namespace
-#ifdef HAVE_FANCY_SIMD
+#ifndef __aarch64__
namespace {
@@ -6414,42 +6420,99 @@ struct BaseHelper {
};
+struct F16 {
+#ifdef HAVE_FANCY_SIMD
+ using Data = __m512;
+ constexpr static int block_size = 16;
+ constexpr static int num_registers = 32;
+ constexpr static int q_step = 8;
+ static inline Data zero() { return _mm512_setzero_ps(); }
+ static inline Data load(const char * ptr, int i) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)ptr + i)); }
+ static inline Data set1(float val) { return _mm512_set1_ps(val); }
+ static inline Data mul(Data v1, Data v2) { return _mm512_mul_ps(v1, v2); }
+ static inline Data sub(Data v1, Data v2) { return _mm512_sub_ps(v1, v2); }
+ static inline Data load(const float * ptr) { return _mm512_loadu_ps(ptr); }
+ static inline void store(float * ptr, Data data) { _mm512_storeu_ps(ptr, data); }
+ static inline float reduce_max(Data data) { return _mm512_reduce_max_ps(data); }
+ static inline float reduce_add(Data data) { return _mm512_reduce_add_ps(data); }
+ static inline Data fmadd(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, v2, prev); }
+ template <int k_step> static inline float reduce_max(const Data * data) {
+ return reduce_T<k_step, _mm512_max_ps, _mm512_reduce_max_ps>(data);
+ }
+ template <int k_step> static inline float reduce_add(const Data * data) {
+ return reduce_T<k_step, _mm512_add_ps, _mm512_reduce_add_ps>(data);
+ }
+#else
+ using Data = __m256;
+ constexpr static int block_size = 8;
+ constexpr static int num_registers = 16;
+ constexpr static int q_step = 8;
+ static inline Data zero() { return _mm256_setzero_ps(); }
+ static inline Data load(const char * ptr, int i) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)ptr + i)); }
+ static inline Data set1(float val) { return _mm256_set1_ps(val); }
+ static inline Data mul(Data v1, Data v2) { return _mm256_mul_ps(v1, v2); }
+ static inline Data load(const float * ptr) { return _mm256_loadu_ps(ptr); }
+ static inline Data sub(Data v1, Data v2) { return _mm256_sub_ps(v1, v2); }
+ static inline void store(float * ptr, Data data) { _mm256_storeu_ps(ptr, data); }
+ static inline Data fmadd(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, v2, prev); }
+ static inline float reduce_max(Data data) { return hmax_float_8(data); }
+ static inline float reduce_add(Data data) { return hsum_float_8(data); }
+ template <int k_step> static inline float reduce_max(const Data * data) {
+ return reduce_T<k_step, _mm256_max_ps, &F16::reduce_max>(data);
+ }
+ template <int k_step> static inline float reduce_add(const Data * data) {
+ return reduce_T<k_step, _mm256_add_ps, &F16::reduce_add>(data);
+ }
+#endif
+ template <int k_step, Data (*Op_combine)(Data, Data), float (*Op)(Data)>
+ static float reduce_T(const Data * data) {
+ float result;
+ if constexpr (k_step/block_size == 1) {
+ result = Op(data[0]);
+ }
+ else if constexpr (k_step/block_size == 2) {
+ result = Op(Op_combine(data[0], data[1]));
+ }
+ else {
+ auto vmax = Op_combine(data[0], data[1]);
+ for (int l = 2; l < k_step/block_size; ++l) vmax = Op_combine(vmax, data[l]);
+ result = Op(vmax);
+ }
+ return result;
+ }
+};
+
template <int D, int step>
struct HelperF16 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
HelperF16(const char * data, int stride) : Base(data, stride) {}
- inline void load(int l1, __m512 * vk) const {
+ inline void load(int l1, F16::Data * vk) const {
auto dr = Base::lblock(l1);
- for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i));
+ for (int i = 0; i < D/F16::block_size; ++i) vk[i] = F16::load(dr, i);
}
- inline void load(int l1, int i, __m512& v1, __m512& v2) const {
- auto dr = (const ggml_half *)Base::lblock(l1);
- v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 0));
- v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 1));
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ //auto dr = (const ggml_half *)Base::lblock(l1);
+ auto dr = Base::lblock(l1);
+ v1 = F16::load(dr, i + 0);
+ v2 = F16::load(dr, i + 1);
}
- inline void load_2(int l1, __m512 * vk) const {
+ inline void load_2(int l1, F16::Data* vk) const {
load(l1+0, vk+0);
load(l1+1, vk+D/16);
}
};
+#ifdef HAVE_FANCY_SIMD
template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
static_assert(step == QK8_0);
using Base = BaseHelper<step>;
+ using F16 = HelperF16<D, step>;
HelperQ80(const char * data, int stride) : Base(data, stride) {}
- //inline void load(int l1, __m512 * vk) const {
- // auto dl = (const block_q8_0 *)Base::lblock(l1);
- // for (int i = 0; i < D/32; ++i) {
- // auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
- // vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+0))));
- // vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+1))));
- // }
- //}
inline void load(int l1, __m512 * vk) const {
auto dl = (const block_q8_0_x4 *)Base::lblock(l1);
if constexpr (D >= 128) {
@@ -6596,10 +6659,11 @@ struct HelperQ41 final : public BaseHelper<step> {
const __m128i mask = _mm_set1_epi8(0xf);
};
+#endif
template <int q_step, int k_step>
struct FlashMS {
- FlashMS(float scale, float softcap) : vscale(_mm512_set1_ps(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
+ FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
inline void init_qstep() {
for (int j = 0; j < q_step; ++j) {
@@ -6607,18 +6671,18 @@ struct FlashMS {
}
}
- inline void update_M_S(int j, __m512 * vk) {
+ inline void update_M_S(int j, F16::Data * vk) {
if (softcap <= 0.0f) {
- for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
+ for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
} else {
- auto v_softcap = _mm512_set1_ps(softcap);
- for (int l = 0; l < k_step/16; ++l) {
- auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
- vk[l] = _mm512_mul_ps(v_softcap, v_tanh(_mm512_mul_ps(vscale, val)));
+ auto v_softcap = F16::set1(softcap);
+ for (int l = 0; l < k_step/F16::block_size; ++l) {
+ auto val = F16::load(cache + k_step*j + F16::block_size*l);
+ vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, val)));
}
}
- float smax = reduce_T<_mm512_reduce_max_ps, _mm512_max_ps>(vk);
+ float smax = F16::reduce_max<k_step>(vk);
if (smax == -INFINITY) {
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
@@ -6628,7 +6692,7 @@ struct FlashMS {
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
- vms[j] = _mm512_set1_ps(m);
+ vms[j] = F16::set1(m);
need_scaling[j] = 1;
S[j] *= m;
} else {
@@ -6637,40 +6701,22 @@ struct FlashMS {
}
M[j] = smax;
}
- auto vm = _mm512_set1_ps(M[j]);
- for (int l = 0; l < k_step/16; ++l) {
- vk[l] = v_expf(_mm512_sub_ps(vk[l], vm));
- _mm512_storeu_ps(cache + k_step*j + 16*l, vk[l]);
+ auto vm = F16::set1(M[j]);
+ for (int l = 0; l < k_step/F16::block_size; ++l) {
+ vk[l] = v_expf(F16::sub(vk[l], vm));
+ F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
}
- S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk);
+ S[j] += F16::reduce_add<k_step>(vk);
}
float cache[q_step*k_step];
float S[q_step], M[q_step];
int need_scaling[q_step];
- __m512 vms[q_step];
- const __m512 vscale;
+ F16::Data vms[q_step];
+ const F16::Data vscale;
const float softcap;
const ggml_half h_inf;
- typedef __m512 (*combine_t)(__m512, __m512);
- typedef float (*reduce_t)(__m512);
- template <reduce_t Op, combine_t Op_combine>
- static inline float reduce_T(const __m512 * vals) {
- float result;
- if constexpr (k_step/16 == 1) {
- result = Op(vals[0]);
- }
- else if constexpr (k_step/16 == 2) {
- result = Op(Op_combine(vals[0], vals[1]));
- }
- else {
- auto vmax = Op_combine(vals[0], vals[1]);
- for (int l = 2; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]);
- result = Op(vmax);
- }
- return result;
- }
};
template <int D, int q_step, int k_step>
@@ -6680,78 +6726,78 @@ struct FlashQKV {
// Hence, for now, we will not handle head sizes of 80 and 112
template <typename VHelper>
inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
- __m512 vk[2*q_step];
- for (int i = 0; i < D/16; i += 2) {
+ F16::Data vk[2*q_step];
+ for (int i = 0; i < D/F16::block_size; i += 2) {
for (int j = 0; j < q_step; ++j) {
if (fms.need_scaling[j] == 2) {
- vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
+ vk[2*j+0] = vk[2*j+1] = F16::zero();
} else {
auto R = qkv_cache + D*j;
- vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
- vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
+ vk[2*j+0] = F16::load(R + F16::block_size*i);
+ vk[2*j+1] = F16::load(R + F16::block_size*(i + 1));
if (fms.need_scaling[j] == 1) {
- vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], fms.vms[j]);
- vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], fms.vms[j]);
+ vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]);
+ vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]);
}
}
}
- __m512 v1, v2;
+ F16::Data v1, v2;
for (int l1 = 0; l1 < k_step; ++l1) {
vh.load(l1, i, v1, v2);
for (int j = 0; j < q_step; ++j) {
- auto vs = _mm512_set1_ps(fms.cache[k_step*j + l1]);
- vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
- vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
+ auto vs = F16::set1(fms.cache[k_step*j + l1]);
+ vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs);
+ vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs);
}
}
for (int j = 0; j < q_step; ++j) {
auto R = qkv_cache + D*j;
- _mm512_storeu_ps(R + 16*i, vk[2*j+0]);
- _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
+ F16::store(R + F16::block_size*(i + 0), vk[2*j+0]);
+ F16::store(R + F16::block_size*(i + 1), vk[2*j+1]);
}
}
}
template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>>
inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
- __m512 vk[2*q_step];
- for (int i = 0; i < D/16; i += 2) {
+ F16::Data vk[2*q_step];
+ for (int i = 0; i < D/F16::block_size; i += 2) {
for (int j = 0; j < nq1; ++j) {
if (fms.need_scaling[j] == 2) {
- vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
+ vk[2*j+0] = vk[2*j+1] = F16::zero();
} else {
auto R = qkv_cache + D*j;
- vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
- vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
+ vk[2*j+0] = F16::load(R + F16::block_size*i);
+ vk[2*j+1] = F16::load(R + F16::block_size*(i + 1));
if (fms.need_scaling[j] == 1) {
- vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], fms.vms[j]);
- vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], fms.vms[j]);
+ vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]);
+ vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]);
}
}
}
- __m512 v1, v2;
+ F16::Data v1, v2;
for (int l1 = 0; l1 < k_step; ++l1) {
vh.load(l1, i, v1, v2);
- for (int j = 0; j < nq1; ++j) {
- auto vs = _mm512_set1_ps(fms.cache[k_step*j + l1]);
- vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
- vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
+ for (int j = 0; j < q_step; ++j) {
+ auto vs = F16::set1(fms.cache[k_step*j + l1]);
+ vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs);
+ vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs);
}
}
for (int j = 0; j < nq1; ++j) {
auto R = qkv_cache + D*j;
- _mm512_storeu_ps(R + 16*i, vk[2*j+0]);
- _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
+ F16::store(R + F16::block_size*(i + 0), vk[2*j+0]);
+ F16::store(R + F16::block_size*(i + 1), vk[2*j+1]);
}
}
}
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const float * R, float * qkv) const {
GGML_ASSERT(fms.S[j] > 0);
- auto norm = _mm512_set1_ps(1/fms.S[j]);
- for (int i = 0; i < D/16; ++i) {
- auto r = _mm512_loadu_ps(R + 16*i);
- _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
+ auto norm = F16::set1(1/fms.S[j]);
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ auto r = F16::load(R + F16::block_size*i);
+ F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
}
}
@@ -6778,15 +6824,15 @@ struct FlashQKV {
template <int D, int q_step, int k_step>
struct FlashQKfp32 {
- static_assert(D%16 == 0 && D <= 256);
- static_assert(k_step%16 == 0);
+ static_assert(D%F16::block_size == 0 && D <= 256);
+ static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
- constexpr static bool is_small_head = D <= 128;
+ constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
template <bool small = is_small_head, class = std::enable_if<small>>
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
- __m512 * qv, __m512 * vk, FlashMS<q_step, k_step>& fms) {
+ F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
// q index is q_step*i1 + m1
// k index is k_step*k1 + l1
const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
@@ -6795,22 +6841,22 @@ struct FlashQKfp32 {
return;
}
auto qr = q + m1*stride_q;
- for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
+ for (int i = 0; i < D/F16::block_size; ++i) qv[i] = F16::load(qr + F16::block_size*i);
if (mp[l1+0] != fms.h_inf) {
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
- fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
+ auto vsum = F16::zero();
+ for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i], qv[i]);
+ fms.cache[k_step*m1 + l1 + 0] = F16::reduce_add(vsum);
}
if (mp[l1+1] != fms.h_inf) {
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
- fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
+ auto vsum = F16::zero();
+ for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i+D/16], qv[i]);
+ fms.cache[k_step*m1 + l1 + 1] = F16::reduce_add(vsum);
}
}
template <bool small = is_small_head, class = std::enable_if<!small>>
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
- __m512 * vk, FlashMS<q_step, k_step>& fms) {
+ F16::Data * vk, FlashMS<q_step, k_step>& fms) {
// q index is q_step*i1 + m1
// k index is k_step*k1 + l1
const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
@@ -6819,18 +6865,18 @@ struct FlashQKfp32 {
return;
}
auto qr = q + m1*stride_q;
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/16; ++i) {
- vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum);
+ auto vsum = F16::zero();
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ vsum = F16::fmadd(vsum, vk[i], F16::load(qr + F16::block_size*i));
}
- fms.cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
+ fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
}
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
- __m512 qv[D/16];
- __m512 vk[D/8];
+ F16::Data qv[D/F16::block_size];
+ F16::Data vk[D/(F16::block_size/2)];
for (int l1 = 0; l1 < k_step; l1 += 2) {
kh.load_2(l1, vk);
for (int m1 = 0; m1 < q_step; ++m1) {
@@ -6842,7 +6888,7 @@ struct FlashQKfp32 {
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
- __m512 vk[D/16];
+ F16::Data vk[D/F16::block_size];
for (int l1 = 0; l1 < k_step; ++l1) {
kh.load(l1, vk);
for (int m1 = 0; m1 < q_step; ++m1) {
@@ -6854,8 +6900,8 @@ struct FlashQKfp32 {
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
- __m512 qv[D/16];
- __m512 vk[D/8];
+ F16::Data qv[D/F16::block_size];
+ F16::Data vk[D/(F16::block_size/2)];
for (int l1 = 0; l1 < k_step; l1 += 2) {
kh.load_2(l1, vk);
for (int m1 = 0; m1 < nq; ++m1) {
@@ -6867,7 +6913,7 @@ struct FlashQKfp32 {
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
- __m512 vk[D/16];
+ F16::Data vk[D/F16::block_size];
for (int l1 = 0; l1 < k_step; ++l1) {
kh.load(l1, vk);
for (int m1 = 0; m1 < nq; ++m1) {
@@ -6885,7 +6931,7 @@ struct FlashQKfp32 {
else {
mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
}
- __m512 vk[k_step/16];
+ F16::Data vk[k_step/F16::block_size];
for (int j = 0; j < q_step; ++j) {
fms.update_M_S(j, vk);
}
@@ -6900,7 +6946,7 @@ struct FlashQKfp32 {
else {
mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
}
- __m512 vk[k_step/16];
+ F16::Data vk[k_step/F16::block_size];
for (int j = 0; j < nq; ++j) {
fms.update_M_S(j, vk);
}
@@ -6959,8 +7005,8 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
// q_step-1 versions of these functions for us, which I though was too much with q_step = 8.
template <int D, int q_step, int k_step>
struct FlashAttn {
- static_assert(D%16 == 0 && D <= 256);
- static_assert(k_step%16 == 0);
+ static_assert(D%F16::block_size == 0 && D <= 256);
+ static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
FlashAttn(float scale, float softcap) : fms(scale, softcap) {}
@@ -7292,6 +7338,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperF16<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
+#ifdef HAVE_FANCY_SIMD
case GGML_TYPE_Q8_0: {
HelperQ80<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -7304,6 +7351,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ41<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
+#endif
default: break;
}
}
@@ -7319,6 +7367,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperF16<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
+#ifdef HAVE_FANCY_SIMD
case GGML_TYPE_Q8_0: {
HelperQ80<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -7331,17 +7380,22 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ41<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
+#endif
default: break;
}
}
inline bool flash_attn_is_supported(ggml_type type) {
+#ifdef HAVE_FANCY_SIMD
#ifdef __AVX512BF16__
return type == GGML_TYPE_F16 || type == GGML_TYPE_BF16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1;
#else
return type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1;
#endif
+#else
+ return type == GGML_TYPE_F16;
+#endif
}
}
@@ -7395,19 +7449,19 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
switch (D) {
case 64:
- iqk_flash_helper_T< 64, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 80:
// iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 112:
// iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
- iqk_flash_helper_T<128, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 256:
- iqk_flash_helper_T<256, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
default:
return false;
}