Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 108 |
1 file changed, 41 insertions, 67 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index e8150ec5..5750b952 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -7,20 +7,14 @@
 // SPDX-License-Identifier: MIT
 //
 
-#if defined IQK_IMPLEMENT
-#undef IQK_IMPLEMENT
-#endif
+#include "iqk_config.h"
 
-#if defined __AVX2__ || defined __ARM_FEATURE_DOTPROD
-#define IQK_IMPLEMENT
-#endif
+#if defined IQK_IMPLEMENT
 
 #include <cstring>
 #include <type_traits>
 #include <vector>
 
-#if defined IQK_IMPLEMENT
-
 #include "ggml-impl.h"
 #include "ggml-quants.h"
 #include "iqk_mul_mat.h"
@@ -100,26 +94,6 @@ struct Perf {
 };
 #endif
 
-#ifdef _MSC_VER
-#define IQK_NOINLINE __declspec(noinline)
-#define IQK_ALWAYS_INLINE inline
-#if !defined __x86_64__ && defined _M_X64
-#define __x86_64__
-#endif
-#else
-#define IQK_NOINLINE __attribute__((__noinline__))
-#define IQK_ALWAYS_INLINE __attribute__((__always_inline__))
-#endif
-
-#if defined __x86_64__
-#if defined HAVE_FANCY_SIMD
-    #undef HAVE_FANCY_SIMD
-#endif
-#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
-    #define HAVE_FANCY_SIMD
-#endif
-#endif
-
 namespace {
 
 typedef struct {
@@ -1472,7 +1446,7 @@ inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
 template <typename Q8, typename Bits>
 inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
     if (j == 0) {
-#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+#ifdef HAVE_FANCY_SIMD
         for (int iy = 0; iy < Q8::nrc_y; ++iy) {
             sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
             sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
@@ -1489,7 +1463,7 @@ inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i,
         }
 #endif
     } else {
-#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+#ifdef HAVE_FANCY_SIMD
         for (int iy = 0; iy < Q8::nrc_y; ++iy) {
             sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
             sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
@@ -2747,7 +2721,7 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
         auto h1 = _mm256_andnot_si256(mask4, hbits);
         auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1);
         auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2);
-        auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(0xff));
+        auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(-1)); // 0xff;
         return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)),
                                                _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))),
                                _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)),
@@ -2843,7 +2817,7 @@ struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
     const __m256i values;
     __m256i data[4];
     const __m256i smask = _mm256_set_epi64x(0x0080004000200010, 0x0008000400020001, 0x0080004000200010, 0x0008000400020001);
-    const __m256i bmask = _mm256_set1_epi16(0xfffe);
+    const __m256i bmask = _mm256_set1_epi16(-2); // 0xfffe;
     const __m128i mask = _mm_set1_epi16(254);
     const __m128i m127 = _mm_set1_epi16(-127);
     const __m128i m128 = _mm_set1_epi16(-128);
@@ -7049,7 +7023,7 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
 template <typename Bits>
 inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
     if (j == 0) {
-#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+#ifdef HAVE_FANCY_SIMD
         auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
         auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
         auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
@@ -7065,7 +7039,7 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons
         sumi[1] = _mm256_add_epi32(p2, p4);
 #endif
     } else {
-#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+#ifdef HAVE_FANCY_SIMD
         auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
         auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
         auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
@@ -7282,7 +7256,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
 
     __m256i accd[nrc_y];
     __m256i val[4];
-#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
+#ifndef HAVE_FANCY_SIMD
     const auto m1_16 = _mm256_set1_epi16(1);
 #endif
 
@@ -7304,7 +7278,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
         for (int i = 0; i < nb/2; ++i) {
             deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
             deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
-#if defined __AVX512VNNI__ && defined __AVX512VL__
+#ifdef HAVE_FANCY_SIMD
             acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
             acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
 #else
@@ -7328,7 +7302,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
             deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
 
             for (int iy = 0; iy < nrc_y; ++iy) {
-#if defined __AVX512VNNI__ && defined __AVX512VL__
+#ifdef HAVE_FANCY_SIMD
                 accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                     val[0], q8.load_quants(iy, i, 0)),
                                     val[1], q8.load_quants(iy, i, 1)),
@@ -7349,7 +7323,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
     if (i < nb) {
         deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
         for (int iy = 0; iy < nrc_y; ++iy) {
-#if defined __AVX512VNNI__ && defined __AVX512VL__
+#ifdef HAVE_FANCY_SIMD
             accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
                                            val[1], q8.load_quants(iy, i/2, 1));
 #else
@@ -7401,7 +7375,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
 
     __m256i accd[nrc_y];
     __m256i val[4];
-#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
+#ifndef HAVE_FANCY_SIMD
    const auto m1_16 = _mm256_set1_epi16(1);
 #endif
 
@@ -7413,7 +7387,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
         __m256i acc[2] = {};
         for (int i = 0; i < nb/2; ++i) {
             deq.prepare4(i, val);
-#if defined __AVX512VNNI__ && defined __AVX512VL__
+#ifdef HAVE_FANCY_SIMD
             acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
                                          val[1], q8.load_quants(0, i, 1));
             acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
@@ -7436,7 +7410,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
         for (int i = 0; i < nb/2; ++i) {
             deq.prepare4(i, val);
             for (int iy = 0; iy < nrc_y; ++iy) {
-#if defined __AVX512VNNI__ && defined __AVX512VL__
+#ifdef HAVE_FANCY_SIMD
                 accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                     val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
                                     val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
@@ -7455,7 +7429,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
     if (i < nb) {
         deq.prepare2(i, val);
         for (int iy = 0; iy < nrc_y; ++iy) {
-#if defined __AVX512VNNI__ && defined __AVX512VL__
+#ifdef HAVE_FANCY_SIMD
             accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
                                            val[1], q8.load_quants(iy, i/2, 1));
 #else
@@ -8537,7 +8511,7 @@ template <typename Float, int nrc_in> struct QFT final : public QFBase {
         xv[1] = load1(ix+1, i);
         xv[2] = load1(ix+2, i);
         xv[3] = load1(ix+3, i);
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX512F__
         auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);
         auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);
         auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);
@@ -14749,7 +14723,7 @@ struct BaseHelper {
 };
 
 struct F16 {
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX512F__
     using Data = __m512;
     constexpr static int block_size = 16;
     constexpr static int num_registers = 32;
@@ -14910,7 +14884,7 @@ struct HelperQ8KV final : public BaseHelper<step> {
         v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
 #else
         auto vd = F16::set1(q8->d);
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX512F__
         v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0))));
         v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1))));
 #else
@@ -14945,7 +14919,7 @@ struct HelperQ80 final : public BaseHelper<step> {
         v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
 #else
         auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX512F__
         v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0))));
         v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1))));
 #else
@@ -15215,7 +15189,7 @@ struct HelperQ40 final : public BaseHelper<step> {
 #else
         auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
         auto q = _mm_loadu_si128((const __m128i *)dl->qs);
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX512F__
         auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
         auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
         v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
@@ -15260,7 +15234,7 @@ struct HelperQ41 final : public BaseHelper<step> {
         auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
         auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m));
         auto q = _mm_loadu_si128((const __m128i *)dl->qs);
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX512F__
         auto ql = _mm_and_si128(q, mask);
         auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
         v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
@@ -15306,7 +15280,7 @@ struct HelperIQ4nl final : public BaseHelper<step> {
 #else
         auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
         auto q = _mm_loadu_si128((const __m128i *)dl->qs);
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX512F__
         auto ql = _mm_shuffle_epi8(values, _mm_and_si128(q, mask));
         auto qh = _mm_shuffle_epi8(values, _mm_and_si128(_mm_srli_epi16(q, 4), mask));
         v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
@@ -15361,7 +15335,7 @@ struct HelperQ60 final : public BaseHelper<step> {
         auto bl = _mm_loadu_si128((const __m128i *)dl->qs);
         uint64_t aux64; std::memcpy(&aux64, dl->qh, 8);
         auto bh = _mm_set_epi64x(aux64, aux64 << 4);
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX512F__
         auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32);
         auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32);
         v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
@@ -15537,6 +15511,22 @@ struct FlashMS {
         }
         return F16::reduce_max<k_step>(vk);
     }
+    static inline __m256 apply_mask(int l, const char * mask, __m256 val, __m256 vinf) {
+        auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
+        m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
+        auto m256 = _mm256_cvtepi16_epi32(m128);
+        auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
+        return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
+    }
+#ifdef __AVX512F__
+    static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) {
+        auto m256 = _mm256_loadu_si256((const __m256i *)mask+l);
+        m256 = _mm256_cmpeq_epi16(m256, _mm256_setzero_si256());
+        auto m512 = _mm512_cvtepi16_epi32(m256);
+        auto mf = _mm512_castsi512_ps(_mm512_or_si512(m512, _mm512_slli_epi32(m512, 16)));
+        return _mm512_or_ps(_mm512_and_ps(mf, val), _mm512_andnot_ps(mf, vinf));
+    }
+#endif
     inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) {
 #ifdef HAVE_FANCY_SIMD
         auto vzero = _mm256_set1_epi16(0);
@@ -15554,15 +15544,9 @@ struct FlashMS {
             }
         }
 #else
-        auto vzero = _mm_set1_epi16(0);
         auto vinf = F16::set1(-INFINITY);
         for (int l = 0; l < k_step/F16::block_size; ++l) {
-            auto m128 = _mm_loadu_si128((const __m128i *)mask + l);
-            m128 = _mm_cmpeq_epi16(m128, vzero);
-            auto m256 = _mm256_cvtepi16_epi32(m128);
-            auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
-            auto val = _mm256_loadu_ps(cache + k_step*j + F16::block_size*l);
-            vk[l] = _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
+            vk[l] = apply_mask(l, mask, F16::load(cache + k_step*j + F16::block_size*l), vinf);
         }
         if (softcap <= 0) {
             for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]);
@@ -15630,14 +15614,12 @@ struct FlashQKV {
                 for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]);
             }
         }
-        //F16::Data v[8];
         F16::Data v0, v1;
         for (int l = 0; l < k_step; l += 4) {
             auto vs0 = F16::set1(fms.cache[l + 0]);
             auto vs1 = F16::set1(fms.cache[l + 1]);
             auto vs2 = F16::set1(fms.cache[l + 2]);
             auto vs3 = F16::set1(fms.cache[l + 3]);
-            //auto vs = F16::set4(fms.cache + l);
             for (int i = 0; i < D/F16::block_size; i += 2) {
                 vh.load(l+0, i, v0, v1);
                 vq[i+0] = F16::fmadd(vq[i+0], v0, vs0);
@@ -15651,14 +15633,6 @@ struct FlashQKV {
                 vh.load(l+3, i, v0, v1);
                 vq[i+0] = F16::fmadd(vq[i+0], v0, vs3);
                 vq[i+1] = F16::fmadd(vq[i+1], v1, vs3);
-                //vq[i+0] = F16::fmadd_lane0(vq[i+0], v[0], vs);
-                //vq[i+1] = F16::fmadd_lane0(vq[i+1], v[4], vs);
-                //vq[i+0] = F16::fmadd_lane1(vq[i+0], v[1], vs);
-                //vq[i+1] = F16::fmadd_lane1(vq[i+1], v[5], vs);
-                //vq[i+0] = F16::fmadd_lane2(vq[i+0], v[2], vs);
-                //vq[i+1] = F16::fmadd_lane2(vq[i+1], v[6], vs);
-                //vq[i+0] = F16::fmadd_lane3(vq[i+0], v[3], vs);
-                //vq[i+1] = F16::fmadd_lane3(vq[i+1], v[7], vs);
             }
         }
         for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]);
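
Note: the new "iqk_config.h" header added by this commit is not shown in the diff. Judging only from the preprocessor blocks the commit deletes from iqk_mul_mat.cpp, it plausibly centralizes the same feature tests roughly as below; this is a sketch reconstructed from the removed lines, not the actual header contents.

// Hypothetical reconstruction of iqk_config.h (assumption: mirrors the
// feature-test blocks removed from iqk_mul_mat.cpp above).
#pragma once

#ifdef _MSC_VER
#define IQK_NOINLINE __declspec(noinline)
#define IQK_ALWAYS_INLINE inline
#if !defined __x86_64__ && defined _M_X64
#define __x86_64__
#endif
#else
#define IQK_NOINLINE __attribute__((__noinline__))
#define IQK_ALWAYS_INLINE __attribute__((__always_inline__))
#endif

// IQK_IMPLEMENT: compile the iqk kernels only on targets with usable SIMD.
#if defined __AVX2__ || defined __ARM_FEATURE_DOTPROD
#define IQK_IMPLEMENT
#endif

// HAVE_FANCY_SIMD: full AVX-512 feature set incl. VNNI, selecting the
// _mm256_dpbusd_epi32/_mm256_dpwssd_epi32 paths over the AVX2 fallback.
#if defined __x86_64__
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
#define HAVE_FANCY_SIMD
#endif
#endif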
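
The new FlashMS::apply_mask overloads implement a branch-free select: the fp16 attention mask is compared against zero, the resulting all-ones/all-zeros lanes are widened to 32 bits, and a bitwise and/andnot/or blend keeps the score where the mask is zero and substitutes -inf elsewhere, so masked positions vanish after softmax. A scalar model of the same logic, for illustration only (apply_mask_scalar is a hypothetical name, not part of the source):

#include <cmath>
#include <cstdint>

// Scalar equivalent of the vectorized blend in FlashMS::apply_mask():
// a zero fp16 mask bit pattern means "attend", anything else means "mask out".
static inline float apply_mask_scalar(uint16_t mask_bits, float score) {
    return mask_bits == 0 ? score : -INFINITY;
}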