-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp | 248
1 file changed, 202 insertions(+), 46 deletions(-)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 3a3b9eba..9267f0f3 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6293,6 +6293,11 @@ inline float32x4_t v_expf(float32x4_t x) {
return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
+inline float16x8_t v_expf(float16x8_t x) {
+ auto val1 = v_expf(vcvt_f32_f16(vget_low_f16(x)));
+ auto val2 = v_expf(vcvt_f32_f16(vget_high_f16(x)));
+ return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+}
inline float32x4_t v_tanh(float32x4_t x) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
@@ -6302,6 +6307,11 @@ inline float32x4_t v_tanh(float32x4_t x) {
return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
//return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
}
+inline float16x8_t v_tanh(float16x8_t x) {
+ auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
+ auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
+ return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+}
#endif
#if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -6401,8 +6411,6 @@ inline __m256 v_tanh(__m256 x) {
#endif
} // namespace
-#ifndef __aarch64__
-
namespace {
template <int k_step>
@@ -6442,7 +6450,7 @@ struct F16 {
template <int k_step> static inline float reduce_add(const Data * data) {
return reduce_T<k_step, _mm512_add_ps, _mm512_reduce_add_ps>(data);
}
-#else
+#elif defined __AVX2__
using Data = __m256;
constexpr static int block_size = 8;
constexpr static int num_registers = 16;
@@ -6463,6 +6471,41 @@ struct F16 {
template <int k_step> static inline float reduce_add(const Data * data) {
return reduce_T<k_step, _mm256_add_ps, &F16::reduce_add>(data);
}
+#else
+ using Data = float16x8_t;
+ constexpr static int block_size = 8;
+ constexpr static int num_registers = 32;
+ constexpr static int q_step = 8;
+ static inline Data zero() { return vdupq_n_f16(0); }
+ static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }
+ static inline Data load(const float16_t * ptr, int i) { return vld1q_f16(ptr + block_size*i); }
+ static inline Data load(const float16_t * ptr) { return vld1q_f16(ptr); }
+ static inline Data load(const float * ptr) {
+ auto val1 = vld1q_f32(ptr);
+ auto val2 = vld1q_f32(ptr+4);
+ return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+ }
+ static inline Data set1(float val) { return vdupq_n_f16(val); }
+ static inline Data mul(Data v1, Data v2) { return vmulq_f16(v1, v2); }
+ static inline Data sub(Data v1, Data v2) { return vsubq_f16(v1, v2); }
+ static inline void store(float * ptr, Data data) {
+ vst1q_f32(ptr+0, vcvt_f32_f16(vget_low_f16(data)));
+ vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
+ }
+ static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
+ static inline void store(float * ptr, float32x4_t data) { vst1q_f32(ptr, data); }
+ static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
+ static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
+ static inline float reduce_add(Data data) {
+ auto sum = vadd_f16(vget_low_f16(data), vget_high_f16(data));
+ return vaddvq_f32(vcvt_f32_f16(sum));
+ }
+ template <int k_step> static inline float reduce_max(const Data * data) {
+ return reduce_T<k_step, vmaxq_f16, &F16::reduce_max>(data);
+ }
+ template <int k_step> static inline float reduce_add(const Data * data) {
+ return reduce_T<k_step, vaddq_f16, &F16::reduce_add>(data);
+ }
#endif
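
For orientation, here is a minimal standalone sketch (not part of the patch) of how the new float16x8_t specialization above is meant to be used: products are accumulated in fp16 registers and widened to fp32 only for the final horizontal reduction. It assumes an AArch64 toolchain with __ARM_FEATURE_FP16_VECTOR_ARITHMETIC; dot_f16 is an illustrative name, not a function from this file.

#include <arm_neon.h>

// dot_f16: fp16 multiply-accumulate, fp32 only for the final reduction,
// mirroring F16::fmadd / F16::reduce_add above. n must be a multiple of 8.
static inline float dot_f16(const float16_t * x, const float16_t * y, int n) {
    float16x8_t acc = vdupq_n_f16(0);                                  // F16::zero()
    for (int i = 0; i < n; i += 8) {
        acc = vfmaq_f16(acc, vld1q_f16(x + i), vld1q_f16(y + i));      // F16::fmadd(acc, F16::load(x, i), F16::load(y, i))
    }
    float16x4_t half = vadd_f16(vget_low_f16(acc), vget_high_f16(acc));
    return vaddvq_f32(vcvt_f32_f16(half));                             // F16::reduce_add(acc)
}
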
template <int k_step, Data (*Op_combine)(Data, Data), float (*Op)(Data)>
static float reduce_T(const Data * data) {
@@ -6663,6 +6706,17 @@ struct HelperQ41 final : public BaseHelper<step> {
template <int q_step, int k_step>
struct FlashMS {
+// Something goes wrong when storing and manipulating K*Q as fp16.
+// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
+// As I wasn't able to find where we lose precision, let's comment this out
+// for now and do the K*Q part in fp32.
+//#ifdef __aarch64__
+// using cache_t = float16_t;
+//#else
+// using cache_t = float;
+//#endif
+ using cache_t = float;
+
FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
inline void init_qstep() {
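
As context for the comment above, a tiny standalone illustration (not from this repo) of fp16's limited precision: a running sum that is rounded back to fp16 after every addition drifts away from the fp32 result. It does not claim to reproduce the LLaMA-3.1-8B failure, only the kind of error the fp32 cache guards against; it assumes an AArch64 compiler where float16_t arithmetic is available.

#include <arm_neon.h>
#include <cstdio>

int main() {
    const int D = 128;
    float16_t acc16 = (float16_t)0.0f;
    float     acc32 = 0.0f;
    for (int i = 0; i < D; ++i) {
        float p = 0.37f * 0.41f;                 // one q*k product (arbitrary values)
        acc16 = (float16_t)((float)acc16 + p);   // rounded back to fp16 after every add
        acc32 += p;                              // fp32 accumulation
    }
    std::printf("fp16 accumulation: %.5f  fp32 accumulation: %.5f\n", (float)acc16, acc32);
    return 0;
}
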
@@ -6671,6 +6725,75 @@ struct FlashMS {
}
}
+#ifdef __aarch64__
+ inline void update_M_S(int j, float32x4_t * vk) {
+ float32x4_t vmax = vdupq_n_f32(-INFINITY);
+ // Something goes wrong when storing and manipulating K*Q as fp16.
+ // It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
+ // As I wasn't able to find where we lose precision, let's comment this out
+ // for now and do the K*Q part in fp32.
+ //if (softcap <= 0.0f) {
+ // for (int l = 0; l < k_step/F16::block_size; ++l) {
+ // auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
+ // vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
+ // vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
+ // vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
+ // }
+ //} else {
+ // auto v_softcap = vdupq_n_f32(softcap);
+ // for (int l = 0; l < k_step/F16::block_size; ++l) {
+ // auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
+ // vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
+ // vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
+ // vk[2*l+0] = vmulq_f32(v_softcap, v_tanh(vk[2*l+0]));
+ // vk[2*l+1] = vmulq_f32(v_softcap, v_tanh(vk[2*l+1]));
+ // vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
+ // }
+ //}
+ auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
+ if (softcap <= 0.0f) {
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
+ vmax = vmaxq_f32(vmax, vk[l]);
+ }
+ } else {
+ auto v_softcap = vdupq_n_f32(softcap);
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
+ vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l]));
+ vmax = vmaxq_f32(vmax, vk[l]);
+ }
+ }
+
+ float smax = vmaxvq_f32(vmax);
+ if (smax == -INFINITY) {
+ std::memset(cache + k_step*j, 0, k_step*sizeof(float));
+ need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
+ return;
+ }
+ need_scaling[j] = 0;
+ if (smax > M[j]) {
+ if (M[j] > -INFINITY) {
+ float m = expf(M[j] - smax);
+ vms[j] = F16::set1(m);
+ need_scaling[j] = 1;
+ S[j] *= m;
+ } else {
+ need_scaling[j] = 2;
+ S[j] = 0;
+ }
+ M[j] = smax;
+ }
+ auto vm = vdupq_n_f32(M[j]);
+ auto vsum = vdupq_n_f32(0);
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = v_expf(vsubq_f32(vk[l], vm));
+ vsum = vaddq_f32(vsum, vk[l]);
+ F16::store(cache + k_step*j + 4*l, vk[l]);
+ }
+ S[j] += vaddvq_f32(vsum);
+ }
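+
+For readers less fluent in NEON, a scalar sketch (not part of the patch) of what the aarch64 update_M_S() above does for a single cache row: the standard streaming-softmax recurrence of flash attention. The softcap branch is omitted and scalar references stand in for the per-row FlashMS members.
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+
+// One row: rescale by 'scale', track the running max M and sum S, rescale the
+// existing sum when the max grows, and store exp(x - M) back into the row.
+void update_M_S_scalar(float * row, int k_step, float scale,
+                       float & M, float & S, int & need_scaling, float & vms) {
+    float smax = -INFINITY;
+    for (int l = 0; l < k_step; ++l) {
+        row[l] *= scale;
+        smax = std::max(smax, row[l]);
+    }
+    if (smax == -INFINITY) {                  // fully masked row: zero it and move on
+        std::memset(row, 0, k_step*sizeof(float));
+        need_scaling = M == -INFINITY ? 2 : 0;
+        return;
+    }
+    need_scaling = 0;
+    if (smax > M) {
+        if (M > -INFINITY) {
+            vms = expf(M - smax);             // factor used later to rescale the V accumulator
+            need_scaling = 1;
+            S *= vms;
+        } else {
+            need_scaling = 2;                 // first unmasked block for this row
+            S = 0;
+        }
+        M = smax;
+    }
+    for (int l = 0; l < k_step; ++l) {
+        row[l] = expf(row[l] - M);            // un-normalized probabilities, written back to the cache row
+        S += row[l];
+    }
+}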
+#else
inline void update_M_S(int j, F16::Data * vk) {
if (softcap <= 0.0f) {
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
@@ -6708,8 +6831,9 @@ struct FlashMS {
}
S[j] += F16::reduce_add<k_step>(vk);
}
+#endif
- float cache[q_step*k_step];
+ cache_t cache[q_step*k_step];
float S[q_step], M[q_step];
int need_scaling[q_step];
F16::Data vms[q_step];
@@ -6722,6 +6846,12 @@ struct FlashMS {
template <int D, int q_step, int k_step>
struct FlashQKV {
+#ifdef __aarch64__
+ using qkv_cache_t = float16_t;
+#else
+ using qkv_cache_t = float;
+#endif
+
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
// Hence, for now, we will not handle head sizes of 80 and 112
template <typename VHelper>
@@ -6792,7 +6922,7 @@ struct FlashQKV {
}
}
- inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const float * R, float * qkv) const {
+ inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
GGML_ASSERT(fms.S[j] > 0);
auto norm = F16::set1(1/fms.S[j]);
for (int i = 0; i < D/F16::block_size; ++i) {
@@ -6819,7 +6949,7 @@ struct FlashQKV {
}
}
- float qkv_cache[D*q_step];
+ qkv_cache_t qkv_cache[D*q_step];
};
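
A minimal sketch (not part of the patch) of how an fp16 qkv_cache row is normalized and widened back to fp32 on aarch64, i.e. what F16::mul plus the fp32 overload of F16::store amount to inside normalize_and_store(); normalize_row is an illustrative name and D is assumed to be a multiple of 8.

#include <arm_neon.h>

// normalize_row: out[i] = R[i] / S, reading the fp16 accumulator row and writing fp32.
static inline void normalize_row(const float16_t * R, float S, float * out, int D) {
    float16x8_t norm = vdupq_n_f16((float16_t)(1.0f/S));          // F16::set1(1/S)
    for (int i = 0; i < D; i += 8) {
        float16x8_t v = vmulq_f16(norm, vld1q_f16(R + i));        // F16::mul(norm, F16::load(R, i))
        vst1q_f32(out + i + 0, vcvt_f32_f16(vget_low_f16(v)));    // F16::store(out + i, v)
        vst1q_f32(out + i + 4, vcvt_f32_f16(vget_high_f16(v)));
    }
}
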
template <int D, int q_step, int k_step>
@@ -6828,10 +6958,14 @@ struct FlashQKfp32 {
static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
+#ifdef __aarch64__
+ constexpr static bool is_small_head = false;
+#else
constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
+#endif
- template <bool small = is_small_head, class = std::enable_if<small>>
- static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
+ template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
+ static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
// q index is q_step*i1 + m1
// k index is k_step*k1 + l1
@@ -6854,8 +6988,8 @@ struct FlashQKfp32 {
}
}
- template <bool small = is_small_head, class = std::enable_if<!small>>
- static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
+ template <bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
+ static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
F16::Data * vk, FlashMS<q_step, k_step>& fms) {
// q index is q_step*i1 + m1
// k index is k_step*k1 + l1
@@ -6872,8 +7006,8 @@ struct FlashQKfp32 {
fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
}
- template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
- static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+ template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
+ static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
F16::Data qv[D/F16::block_size];
F16::Data vk[D/(F16::block_size/2)];
@@ -6885,9 +7019,9 @@ struct FlashQKfp32 {
}
}
- template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
+ template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
- const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+ const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
F16::Data vk[D/F16::block_size];
for (int l1 = 0; l1 < k_step; ++l1) {
kh.load(l1, vk);
@@ -6897,8 +7031,8 @@ struct FlashQKfp32 {
}
}
- template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
- static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+ template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
+ static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
F16::Data qv[D/F16::block_size];
F16::Data vk[D/(F16::block_size/2)];
@@ -6910,9 +7044,9 @@ struct FlashQKfp32 {
}
}
- template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
+ template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
- const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+ const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
F16::Data vk[D/F16::block_size];
for (int l1 = 0; l1 < k_step; ++l1) {
kh.load(l1, vk);
@@ -6922,8 +7056,8 @@ struct FlashQKfp32 {
}
}
- template <typename KHelper>
- static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+ template <typename KHelper, typename q_float>
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
if constexpr (is_small_head) {
mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
@@ -6931,14 +7065,21 @@ struct FlashQKfp32 {
else {
mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
}
+#ifdef __aarch64__
+ float32x4_t vk[k_step/4];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk);
+ }
+#else
F16::Data vk[k_step/F16::block_size];
for (int j = 0; j < q_step; ++j) {
fms.update_M_S(j, vk);
}
+#endif
}
- template <typename KHelper>
- static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+ template <typename KHelper, typename q_float>
+ static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
if constexpr (is_small_head) {
mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
@@ -6946,11 +7087,33 @@ struct FlashQKfp32 {
else {
mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
}
+#ifdef __aarch64__
+ float32x4_t vk[k_step/4];
+ for (int j = 0; j < nq; ++j) {
+ fms.update_M_S(j, vk);
+ }
+#else
F16::Data vk[k_step/F16::block_size];
for (int j = 0; j < nq; ++j) {
fms.update_M_S(j, vk);
}
+#endif
}
+
+#ifdef __aarch64__
+ static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
+ for (int i = 0; i < nq; ++i) {
+ for (int j = 0; j < D; j += 8) {
+ auto val1_f32 = vld1q_f32(q + j + 0);
+ auto val2_f32 = vld1q_f32(q + j + 4);
+ auto val_f16 = vcombine_f16(vcvt_f16_f32(val1_f32), vcvt_f16_f32(val2_f32));
+ vst1q_f16(q_f16 + j, val_f16);
+ }
+ q += stride_q;
+ q_f16 += D;
+ }
+ }
+#endif
};
template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
@@ -6958,13 +7121,23 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
FlashMS<q_step, k_step>& fms,
FlashQKV<D, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv) {
+#ifdef __aarch64__
+ float16_t q_f16[D*q_step];
+#endif
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
fms.init_qstep();
kh.reset_block();
vh.reset_block();
+#ifdef __aarch64__
+ KQHelper::convert(q_step, stride_q, q, q_f16);
+#endif
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#ifdef __aarch64__
+ KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms);
+#else
KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
+#endif
fqkv.accumulate_qkv(vh, fms);
kh.next_block();
vh.next_block();
@@ -6981,9 +7154,16 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
fms.init_qstep();
kh.reset_block();
vh.reset_block();
+#ifdef __aarch64__
+ KQHelper::convert(n_left, stride_q, q, q_f16);
+#endif
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#ifdef __aarch64__
+ KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms);
+#else
KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
+#endif
fqkv.accumulate_qkv(n_left, vh, fms);
kh.next_block();
vh.next_block();
@@ -7469,30 +7649,6 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
return true;
}
-#else
-// TODO
-bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
- [[maybe_unused]] int int_type_v, // type of v
- [[maybe_unused]] int D, // head size
- [[maybe_unused]] int nq, // number of columns in q
- [[maybe_unused]] int nk, // number of rows in k
- [[maybe_unused]] int stride_q, // distance between q columns in bytes
- [[maybe_unused]] int stride_k, // distance between k rows in bytes
- [[maybe_unused]] int stride_v, // distance between v rows in bytes
- [[maybe_unused]] int stride_m, // distance between mask rows (in bytes
- [[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes)
- [[maybe_unused]] const float * q, // q matrix.
- [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements
- [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements
- [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
- [[maybe_unused]] float scale, // scale applied before softmax
- [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
- [[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q))
- return false;
-}
-
-#endif
-
#else // IQK_IMPLEMENT
bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {