-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 248
1 file changed, 202 insertions(+), 46 deletions(-)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 3a3b9eba..9267f0f3 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6293,6 +6293,11 @@ inline float32x4_t v_expf(float32x4_t x) {
     return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                      vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
 }
+inline float16x8_t v_expf(float16x8_t x) {
+    auto val1 = v_expf(vcvt_f32_f16(vget_low_f16(x)));
+    auto val2 = v_expf(vcvt_f32_f16(vget_high_f16(x)));
+    return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+}
 inline float32x4_t v_tanh(float32x4_t x) {
     const float32x4_t one = vdupq_n_f32(1.0f);
     const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
@@ -6302,6 +6307,11 @@ inline float32x4_t v_tanh(float32x4_t x) {
     return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
     //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
 }
+inline float16x8_t v_tanh(float16x8_t x) {
+    auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
+    auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
+    return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+}
 #endif
 
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -6401,8 +6411,6 @@ inline __m256 v_tanh(__m256 x) {
 #endif
 }
 }  // namespace
-#ifndef __aarch64__
-
 namespace {
 
 template <int k_step>
@@ -6442,7 +6450,7 @@ struct F16 {
     template <int k_step> static inline float reduce_add(const Data * data) {
         return reduce_T<k_step, _mm512_add_ps, _mm512_reduce_add_ps>(data);
     }
-#else
+#elif defined __AVX2__
     using Data = __m256;
     constexpr static int block_size = 8;
     constexpr static int num_registers = 16;
@@ -6463,6 +6471,41 @@ struct F16 {
     template <int k_step> static inline float reduce_add(const Data * data) {
         return reduce_T<k_step, _mm256_add_ps, &F16::reduce_add>(data);
     }
+#else
+    using Data = float16x8_t;
+    constexpr static int block_size = 8;
+    constexpr static int num_registers = 32;
+    constexpr static int q_step = 8;
+    static inline Data zero() { return vdupq_n_f16(0); }
+    static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }
+    static inline Data load(const float16_t * ptr, int i) { return vld1q_f16(ptr + block_size*i); }
+    static inline Data load(const float16_t * ptr) { return vld1q_f16(ptr); }
+    static inline Data load(const float * ptr) {
+        auto val1 = vld1q_f32(ptr);
+        auto val2 = vld1q_f32(ptr+4);
+        return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+    }
+    static inline Data set1(float val) { return vdupq_n_f16(val); }
+    static inline Data mul(Data v1, Data v2) { return vmulq_f16(v1, v2); }
+    static inline Data sub(Data v1, Data v2) { return vsubq_f16(v1, v2); }
+    static inline void store(float * ptr, Data data) {
+        vst1q_f32(ptr+0, vcvt_f32_f16(vget_low_f16(data)));
+        vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
+    }
+    static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
+    static inline void store(float * ptr, float32x4_t data) { vst1q_f32(ptr, data); }
+    static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
+    static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
+    static inline float reduce_add(Data data) {
+        auto sum = vadd_f16(vget_low_f16(data), vget_high_f16(data));
+        return vaddvq_f32(vcvt_f32_f16(sum));
+    }
+    template <int k_step> static inline float reduce_max(const Data * data) {
+        return reduce_T<k_step, vmaxq_f16, &F16::reduce_max>(data);
+    }
+    template <int k_step> static inline float reduce_add(const Data * data) {
+        return reduce_T<k_step, vaddq_f16, &F16::reduce_add>(data);
+    }
 #endif
     template <int k_step, Data (*Op_combine)(Data, Data), float (*Op)(Data)>
     static float reduce_T(const Data * data) {
@@ -6663,6 +6706,17 @@ struct HelperQ41 final : public BaseHelper<step> {
 
 template <int q_step, int k_step>
 struct FlashMS {
+// Something goes wrong when storing and manipulating K*Q as fp16.
+// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
+// As I wasn't able to find where we lose precision, let's comment this out
+// for now and do the K*Q part in fp32.
+//#ifdef __aarch64__
+//    using cache_t = float16_t;
+//#else
+//    using cache_t = float;
+//#endif
+    using cache_t = float;
+
     FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
 
     inline void init_qstep() {
@@ -6671,6 +6725,75 @@ struct FlashMS {
         }
     }
 
+#ifdef __aarch64__
+    inline void update_M_S(int j, float32x4_t * vk) {
+        float32x4_t vmax = vdupq_n_f32(-INFINITY);
+        // Something goes wrong when storing and manipulating K*Q as fp16.
+        // It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
+        // As I wasn't able to find where we lose precision, let's comment this out
+        // for now and do the K*Q part in fp32.
+        //if (softcap <= 0.0f) {
+        //    for (int l = 0; l < k_step/F16::block_size; ++l) {
+        //        auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
+        //        vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
+        //        vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
+        //        vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
+        //    }
+        //} else {
+        //    auto v_softcap = vdupq_n_f32(softcap);
+        //    for (int l = 0; l < k_step/F16::block_size; ++l) {
+        //        auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
+        //        vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
+        //        vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
+        //        vk[2*l+0] = vmulq_f32(v_softcap, v_tanh(vk[2*l+0]));
+        //        vk[2*l+1] = vmulq_f32(v_softcap, v_tanh(vk[2*l+1]));
+        //        vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
+        //    }
+        //}
+        auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
+        if (softcap <= 0.0f) {
+            for (int l = 0; l < k_step/4; ++l) {
+                vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
+                vmax = vmaxq_f32(vmax, vk[l]);
+            }
+        } else {
+            auto v_softcap = vdupq_n_f32(softcap);
+            for (int l = 0; l < k_step/4; ++l) {
+                vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
+                vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l]));
+                vmax = vmaxq_f32(vmax, vk[l]);
+            }
+        }
+
+        float smax = vmaxvq_f32(vmax);
+        if (smax == -INFINITY) {
+            std::memset(cache + k_step*j, 0, k_step*sizeof(float));
+            need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
+            return;
+        }
+        need_scaling[j] = 0;
+        if (smax > M[j]) {
+            if (M[j] > -INFINITY) {
+                float m = expf(M[j] - smax);
+                vms[j] = F16::set1(m);
+                need_scaling[j] = 1;
+                S[j] *= m;
+            } else {
+                need_scaling[j] = 2;
+                S[j] = 0;
+            }
+            M[j] = smax;
+        }
+        auto vm = vdupq_n_f32(M[j]);
+        auto vsum = vdupq_n_f32(0);
+        for (int l = 0; l < k_step/4; ++l) {
+            vk[l] = v_expf(vsubq_f32(vk[l], vm));
+            vsum = vaddq_f32(vsum, vk[l]);
+            F16::store(cache + k_step*j + 4*l, vk[l]);
+        }
+        S[j] += vaddvq_f32(vsum);
+    }
+#else
     inline void update_M_S(int j, F16::Data * vk) {
         if (softcap <= 0.0f) {
             for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
@@ -6708,8 +6831,9 @@ struct FlashMS {
         }
         S[j] += F16::reduce_add<k_step>(vk);
     }
+#endif
 
-    float cache[q_step*k_step];
+    cache_t cache[q_step*k_step];
     float S[q_step], M[q_step];
     int need_scaling[q_step];
     F16::Data vms[q_step];
@@ -6722,6 +6846,12 @@ struct FlashMS {
 
 template <int D, int q_step, int k_step>
 struct FlashQKV {
+#ifdef __aarch64__
+    using qkv_cache_t = float16_t;
+#else
+    using qkv_cache_t = float;
+#endif
+
     // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
     // Hence, for now, we will not handle head sizes of 80 and 112
     template <typename VHelper>
@@ -6792,7 +6922,7 @@ struct FlashQKV {
         }
     }
 
-    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const float * R, float * qkv) const {
+    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
         GGML_ASSERT(fms.S[j] > 0);
         auto norm = F16::set1(1/fms.S[j]);
         for (int i = 0; i < D/F16::block_size; ++i) {
@@ -6819,7 +6949,7 @@ struct FlashQKV {
         }
     }
 
-    float qkv_cache[D*q_step];
+    qkv_cache_t qkv_cache[D*q_step];
 };
 
 template <int D, int q_step, int k_step>
@@ -6828,10 +6958,14 @@ struct FlashQKfp32 {
     static_assert(k_step%F16::block_size == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
 
+#ifdef __aarch64__
+    constexpr static bool is_small_head = false;
+#else
     constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
+#endif
 
-    template <bool small = is_small_head, class = std::enable_if<small>>
-    static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
+    template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
+    static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
             F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
         // q index is q_step*i1 + m1
         // k index is k_step*k1 + l1
@@ -6854,8 +6988,8 @@ struct FlashQKfp32 {
         }
     }
 
-    template <bool small = is_small_head, class = std::enable_if<!small>>
-    static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
+    template <bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
+    static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
             F16::Data * vk, FlashMS<q_step, k_step>& fms) {
         // q index is q_step*i1 + m1
         // k index is k_step*k1 + l1
@@ -6872,8 +7006,8 @@ struct FlashQKfp32 {
         fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
     }
 
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
-    static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+    template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
+    static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
             FlashMS<q_step, k_step>& fms) {
         F16::Data qv[D/F16::block_size];
         F16::Data vk[D/(F16::block_size/2)];
@@ -6885,9 +7019,9 @@ struct FlashQKfp32 {
         }
     }
 
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
+    template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
     static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
-            const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+            const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
         F16::Data vk[D/F16::block_size];
         for (int l1 = 0; l1 < k_step; ++l1) {
             kh.load(l1, vk);
@@ -6897,8 +7031,8 @@ struct FlashQKfp32 {
         }
     }
 
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
-    static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+    template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
+    static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
             FlashMS<q_step, k_step>& fms) {
         F16::Data qv[D/F16::block_size];
         F16::Data vk[D/(F16::block_size/2)];
@@ -6910,9 +7044,9 @@ struct FlashQKfp32 {
         }
     }
 
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
+    template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
     static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
-            const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+            const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
         F16::Data vk[D/F16::block_size];
         for (int l1 = 0; l1 < k_step; ++l1) {
             kh.load(l1, vk);
@@ -6922,8 +7056,8 @@ struct FlashQKfp32 {
         }
     }
 
-    template <typename KHelper>
-    static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+    template <typename KHelper, typename q_float>
+    static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
             FlashMS<q_step, k_step>& fms) {
         if constexpr (is_small_head) {
             mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
@@ -6931,14 +7065,21 @@ struct FlashQKfp32 {
         else {
             mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
         }
+#ifdef __aarch64__
+        float32x4_t vk[k_step/4];
+        for (int j = 0; j < q_step; ++j) {
+            fms.update_M_S(j, vk);
+        }
+#else
         F16::Data vk[k_step/F16::block_size];
        for (int j = 0; j < q_step; ++j) {
             fms.update_M_S(j, vk);
         }
+#endif
     }
 
-    template <typename KHelper>
-    static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+    template <typename KHelper, typename q_float>
+    static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
             FlashMS<q_step, k_step>& fms) {
         if constexpr (is_small_head) {
             mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
@@ -6946,11 +7087,33 @@ struct FlashQKfp32 {
         else {
             mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
         }
+#ifdef __aarch64__
+        float32x4_t vk[k_step/4];
+        for (int j = 0; j < nq; ++j) {
+            fms.update_M_S(j, vk);
+        }
+#else
         F16::Data vk[k_step/F16::block_size];
         for (int j = 0; j < nq; ++j) {
             fms.update_M_S(j, vk);
         }
+#endif
     }
+
+#ifdef __aarch64__
+    static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
+        for (int i = 0; i < nq; ++i) {
+            for (int j = 0; j < D; j += 8) {
+                auto val1_f32 = vld1q_f32(q + j + 0);
+                auto val2_f32 = vld1q_f32(q + j + 4);
+                auto val_f16  = vcombine_f16(vcvt_f16_f32(val1_f32), vcvt_f16_f32(val2_f32));
+                vst1q_f16(q_f16 + j, val_f16);
+            }
+            q += stride_q;
+            q_f16 += D;
+        }
+    }
+#endif
 };
 
 template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
@@ -6958,13 +7121,23 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
                     FlashMS<q_step, k_step>& fms, FlashQKV<D, q_step, k_step>& fqkv,
                     const float * q, const char * mask, float * qkv) {
+#ifdef __aarch64__
+    float16_t q_f16[D*q_step];
+#endif
     for (int i1 = 0; i1 < nq1/q_step; ++i1) {
         fms.init_qstep();
         kh.reset_block();
         vh.reset_block();
+#ifdef __aarch64__
+        KQHelper::convert(q_step, stride_q, q, q_f16);
+#endif
         auto mr = mask;
         for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#ifdef __aarch64__
+            KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms);
+#else
             KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
+#endif
             fqkv.accumulate_qkv(vh, fms);
             kh.next_block();
             vh.next_block();
@@ -6981,9 +7154,16 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
         fms.init_qstep();
         kh.reset_block();
         vh.reset_block();
+#ifdef __aarch64__
+        KQHelper::convert(n_left, stride_q, q, q_f16);
+#endif
        auto mr = mask;
        for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#ifdef __aarch64__
+            KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms);
+#else
             KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
+#endif
             fqkv.accumulate_qkv(n_left, vh, fms);
             kh.next_block();
             vh.next_block();
@@ -7469,30 +7649,6 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
     return true;
 }
 
-#else
-// TODO
-bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k,        // type of k
-                            [[maybe_unused]] int int_type_v,        // type of v
-                            [[maybe_unused]] int D,                 // head size
-                            [[maybe_unused]] int nq,                // number of columns in q
-                            [[maybe_unused]] int nk,                // number of rows in k
-                            [[maybe_unused]] int stride_q,          // distance between q columns in bytes
-                            [[maybe_unused]] int stride_k,          // distance between k rows in bytes
-                            [[maybe_unused]] int stride_v,          // distance between v rows in bytes
-                            [[maybe_unused]] int stride_m,          // distance between mask rows (in bytes
-                            [[maybe_unused]] int stride_qkv,        // distance between rows in mask (in bytes)
-                            [[maybe_unused]] const float * q,       // q matrix.
-                            [[maybe_unused]] const void  * k,       // k matrix. Assumed to be fp16, nq x nk elements
-                            [[maybe_unused]] const void  * v,       // v matrix. Assumed to be fp16, nq x nk elements
-                            [[maybe_unused]] const void  * mask,    // mask. If not null, assumed to be fp16. nq x nk elements
-                            [[maybe_unused]] float         scale,   // scale applied before softmax
-                            [[maybe_unused]] float         softcap, // if > 0, a "soft-cap" operation is applied before softmax
-                            [[maybe_unused]] float       * qkv) {   // v*softmax(scale*(k*q))
-    return false;
-}
-
-#endif
-
 #else  // IQK_IMPLEMENT
 
 bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
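The new v_expf(float16x8_t)/v_tanh(float16x8_t) overloads and most of the NEON F16 helpers above follow a single pattern: split the 8 fp16 lanes into two float32x4_t halves, do the arithmetic in fp32, and narrow back to fp16. Below is a minimal standalone sketch of that pattern. It is my own example, not part of the commit; it assumes an aarch64 toolchain with fp16 support (e.g. -march=armv8.2-a+fp16) and uses scalar expf so it has no other dependencies.

// Minimal sketch (not from the commit) of the fp16 -> two fp32 halves -> fp16 pattern.
#include <arm_neon.h>
#include <cmath>
#include <cstdio>

// Same split/convert/recombine idea as v_expf(float16x8_t) in the diff,
// but using scalar expf to keep the example self-contained.
static inline float16x8_t expf_f16x8(float16x8_t x) {
    float lo[4], hi[4];
    vst1q_f32(lo, vcvt_f32_f16(vget_low_f16(x)));   // lower 4 lanes, widened to fp32
    vst1q_f32(hi, vcvt_f32_f16(vget_high_f16(x)));  // upper 4 lanes, widened to fp32
    for (int i = 0; i < 4; ++i) { lo[i] = expf(lo[i]); hi[i] = expf(hi[i]); }
    return vcombine_f16(vcvt_f16_f32(vld1q_f32(lo)), vcvt_f16_f32(vld1q_f32(hi)));
}

int main() {
    float16_t in[8], out[8];
    for (int i = 0; i < 8; ++i) in[i] = (float16_t)(0.25f*i - 1.0f);
    vst1q_f16(out, expf_f16x8(vld1q_f16(in)));
    for (int i = 0; i < 8; ++i) printf("exp(%g) ~ %g\n", (float)in[i], (float)out[i]);
    return 0;
}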
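For context on the new NEON branch of the F16 wrapper struct: it is used to form the K*Q dot products by loading fp16 data, accumulating with fmadd, and reducing to a scalar, with the final reduction finished in fp32 exactly as reduce_add does in the diff. The sketch below mirrors that usage; the names F16Demo and dot_f16 are illustrative only and do not exist in the source.

// Illustrative sketch (aarch64-only, simplified): load / fmadd / reduce with fp16 NEON.
// Build with e.g. g++ -O2 -march=armv8.2-a+fp16.
#include <arm_neon.h>
#include <cstdio>

struct F16Demo {
    using Data = float16x8_t;
    static Data  zero() { return vdupq_n_f16(0); }
    static Data  load(const float16_t * p) { return vld1q_f16(p); }
    static Data  fmadd(Data acc, Data a, Data b) { return vfmaq_f16(acc, a, b); }
    static float reduce_add(Data v) {
        auto sum = vadd_f16(vget_low_f16(v), vget_high_f16(v)); // fold 8 lanes to 4
        return vaddvq_f32(vcvt_f32_f16(sum));                   // finish the sum in fp32
    }
};

// Dot product of two fp16 vectors of length n (n must be a multiple of 8).
static float dot_f16(const float16_t * a, const float16_t * b, int n) {
    auto acc = F16Demo::zero();
    for (int i = 0; i < n; i += 8) acc = F16Demo::fmadd(acc, F16Demo::load(a+i), F16Demo::load(b+i));
    return F16Demo::reduce_add(acc);
}

int main() {
    float16_t a[16], b[16];
    for (int i = 0; i < 16; ++i) { a[i] = (float16_t)1.0f; b[i] = (float16_t)0.5f; }
    printf("dot = %g\n", dot_f16(a, b, 16));  // expect 8
    return 0;
}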
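FlashMS::update_M_S implements the usual streaming-softmax bookkeeping of flash attention: for each query row it keeps the running maximum M of the scaled scores and the running sum S of exp(score - M), and records in need_scaling whether the V accumulator built from earlier k-blocks must be rescaled (1) or reset (2). The scalar reference below is my own paraphrase of the vectorized aarch64 code above, with a hypothetical RowState struct standing in for the per-row members of FlashMS.

// Scalar sketch of the per-row running max/sum update done by update_M_S.
#include <algorithm>
#include <cmath>
#include <cstring>

struct RowState {
    float M = -INFINITY;   // running max of scaled scores seen so far
    float S = 0.0f;        // running sum of exp(score - M)
    float vms = 1.0f;      // factor for the V accumulator when need_scaling == 1
    int   need_scaling = 0;
};

// Process one k-block of raw K*Q scores for one query row.
// On return, scores[] holds exp(score - M) for this block and S has been updated.
inline void update_row(RowState & st, float scale, float softcap, float * scores, int n) {
    float smax = -INFINITY;
    for (int l = 0; l < n; ++l) {
        float v = scale * scores[l];
        if (softcap > 0) v = softcap * tanhf(v);   // optional soft-cap before softmax
        scores[l] = v;
        smax = std::max(smax, v);
    }
    if (smax == -INFINITY) {                       // fully masked block
        std::memset(scores, 0, n*sizeof(float));
        st.need_scaling = st.M == -INFINITY ? 2 : 0;
        return;
    }
    st.need_scaling = 0;
    if (smax > st.M) {
        if (st.M > -INFINITY) {
            st.vms = expf(st.M - smax);            // earlier blocks were weighted with the
            st.S *= st.vms;                        // old maximum, so rescale their contribution
            st.need_scaling = 1;
        } else {
            st.S = 0;
            st.need_scaling = 2;                   // nothing valid accumulated yet; reset
        }
        st.M = smax;
    }
    for (int l = 0; l < n; ++l) {
        scores[l] = expf(scores[l] - st.M);
        st.S += scores[l];
    }
}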