-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp | 156
1 file changed, 121 insertions(+), 35 deletions(-)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 9267f0f3..5a8cbce2 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6548,22 +6548,23 @@ struct HelperF16 final : public BaseHelper<step> {
     }
 };
 
-#ifdef HAVE_FANCY_SIMD
+#if defined __AVX2__
 template <int D, int step>
 struct HelperQ80 final : public BaseHelper<step> {
     static_assert(step == QK8_0);
     using Base = BaseHelper<step>;
-    using F16 = HelperF16<D, step>;
+    //using F16 = HelperF16<D, step>;
     HelperQ80(const char * data, int stride) : Base(data, stride) {}
 
-    inline void load(int l1, __m512 * vk) const {
+    inline void load(int l1, F16::Data * vk) const {
         auto dl = (const block_q8_0_x4 *)Base::lblock(l1);
         if constexpr (D >= 128) {
-            __m512 vd[4];
+            F16::Data vd[4];
             for (int ib = 0; ib < D/128; ++ib) {
                 const auto& b8 = dl[ib];
                 auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)b8.d));
                 auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
+#ifdef HAVE_FANCY_SIMD
                 auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
                 vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
                 vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
@@ -6573,29 +6574,57 @@ struct HelperQ80 final : public BaseHelper<step> {
                     vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+0))));
                     vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+1))));
                 }
+#else
+                vd[0] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(0, 0, 0, 0));
+                vd[1] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(1, 1, 1, 1));
+                vd[2] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(2, 2, 2, 2));
+                vd[3] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(3, 3, 3, 3));
+                for (int i = 0; i < 4; ++i) {
+                    vk[16*ib+4*i+0] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 0)))));
+                    vk[16*ib+4*i+1] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 8)))));
+                    vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+16)))));
+                    vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+24)))));
+                }
+#endif
             }
         } else {
             for (int i = 0; i < D/32; ++i) {
                 const auto& b8 = dl[i/4];
                 int ii = i%4;
-                auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(b8.d[ii]));
+                auto vd = F16::set1(GGML_FP16_TO_FP32(b8.d[ii]));
+#ifdef HAVE_FANCY_SIMD
                 vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
                 vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+1))));
+#else
+                vk[4*i+0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+ 0)))));
+                vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+ 8)))));
+                vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+16)))));
+                vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+24)))));
+#endif
            }
        }
    }
 
-    inline void load(int l1, int i, __m512& v1, __m512& v2) const {
-        auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + i/8;
-        int ii = (i/2)%4;
-        auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d[ii]));
+    inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+        // Say D = 256 -> i is 0, 2, 4, 6, 8, ..., 28, 30. 128/8 = 16 -> we use 1st block of 128 for i = 0, 2, ..., 14, second for i = 16, 18, ..., 30
+        // i = 0, 2 -> ii = 0, i = 4, 6 -> ii = 1, i = 8, 10 -> ii = 2, i = 12, 14 -> ii = 3, i = 16, 18 -> ii = 0, etc.
+        // i*F16::block_size/128
+        int j = F16::block_size*i;
+        auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + j/(4*QK8_0);
+        int ii = (j/QK8_0)%4;
+        auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d[ii]));
+#ifdef HAVE_FANCY_SIMD
         v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0))));
         v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+1))));
+#else
+        v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32)))));
+        v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32+8)))));
+#endif
     }
 
-    inline void load_2(int l1, __m512 * vk) const {
+    inline void load_2(int l1, F16::Data * vk) const {
         load(l1+0, vk+0);
-        load(l1+1, vk+D/16);
+        load(l1+1, vk+D/F16::block_size);
     }
 };
 
@@ -6606,11 +6635,11 @@ struct HelperQ40 final : public BaseHelper<step> {
 
     HelperQ40(const char * data, int stride) : Base(data, stride) {}
 
-    inline void load(int l1, __m512 * vk) const {
+    inline void load(int l1, F16::Data * vk) const {
         auto dl = (const block_q4_0 *)Base::lblock(l1);
         if constexpr (D >= 128) {
             ggml_half aux[4];
-            __m512 vd[4];
+            F16::Data vd[4];
             for (int ib = 0; ib < D/128; ++ib) {
                 for (int i = 0; i < 4; ++i) {
                     auto& b4 = dl[4*ib+i];
@@ -6618,11 +6647,21 @@ struct HelperQ40 final : public BaseHelper<step> {
                     auto q = _mm_loadu_si128((const __m128i *)b4.qs);
                     auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
                     auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
+#ifdef HAVE_FANCY_SIMD
                     vk[8*ib+2*i+0] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql));
                     vk[8*ib+2*i+1] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh));
+#else
+                    auto ql16 = _mm256_cvtepi8_epi16(ql);
+                    auto qh16 = _mm256_cvtepi8_epi16(qh);
+                    vk[16*ib+4*i+0] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16)));
+                    vk[16*ib+4*i+1] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1)));
+                    vk[16*ib+4*i+2] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16)));
+                    vk[16*ib+4*i+3] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1)));
+#endif
                 }
                 auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)aux));
                 auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
+#ifdef HAVE_FANCY_SIMD
                 auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
                 vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
                 vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
@@ -6632,32 +6671,61 @@ struct HelperQ40 final : public BaseHelper<step> {
                     vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+0]);
                     vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+1]);
                 }
+#else
+                vd[0] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(0, 0, 0, 0));
+                vd[1] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(1, 1, 1, 1));
+                vd[2] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(2, 2, 2, 2));
+                vd[3] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(3, 3, 3, 3));
+                for (int i = 0; i < 4; ++i) {
+                    vk[16*ib+4*i+0] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+0]);
+                    vk[16*ib+4*i+1] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+1]);
+                    vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+2]);
+                    vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+3]);
+                }
+#endif
             }
         } else {
             for (int i = 0; i < D/32; ++i) {
-                auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
+                auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d));
                 auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
                 auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
                 auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
+#ifdef HAVE_FANCY_SIMD
                 vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
                 vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+                auto ql16 = _mm256_cvtepi8_epi16(ql);
+                auto qh16 = _mm256_cvtepi8_epi16(qh);
+                vk[4*i+0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))));
+                vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))));
+                vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))));
+                vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))));
+#endif
             }
         }
     }
 
-    inline void load(int l1, int i, __m512& v1, __m512& v2) const {
-        auto dl = (const block_q4_0 *)Base::lblock(l1) + i/2;
-        auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
+    inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+        int j = F16::block_size*i;
+        auto dl = (const block_q4_0 *)Base::lblock(l1) + j/QK4_0;
+        auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
         auto q = _mm_loadu_si128((const __m128i *)dl->qs);
+#ifdef HAVE_FANCY_SIMD
         auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
         auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
         v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
         v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+        if (j%QK4_0) q = _mm_srli_epi16(q, 4);
+        auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_and_si128(q, mask), m8));
+        v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
+        v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
+#endif
     }
 
-    inline void load_2(int l1, __m512 * vk) const {
+    inline void load_2(int l1, F16::Data * vk) const {
         load(l1+0, vk+0);
-        load(l1+1, vk+D/16);
+        load(l1+1, vk+D/F16::block_size);
     }
 
     const __m128i mask = _mm_set1_epi8(0xf);
@@ -6671,33 +6739,51 @@ struct HelperQ41 final : public BaseHelper<step> {
 
     HelperQ41(const char * data, int stride) : Base(data, stride) {}
 
-    inline void load(int l1, __m512 * vk) const {
+    inline void load(int l1, F16::Data * vk) const {
         auto dl = (const block_q4_1 *)Base::lblock(l1);
         for (int i = 0; i < D/32; ++i) {
-            auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
-            auto vm = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].m));
+            auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d));
+            auto vm = F16::set1(GGML_FP16_TO_FP32(dl[i].m));
             auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
             auto ql = _mm_and_si128(q, mask);
             auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
+#ifdef HAVE_FANCY_SIMD
             vk[2*i+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
             vk[2*i+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
+#else
+            auto ql16 = _mm256_cvtepi8_epi16(ql);
+            auto qh16 = _mm256_cvtepi8_epi16(qh);
+            vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))), vm);
+            vk[4*i+1] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))), vm);
+            vk[4*i+2] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))), vm);
+            vk[4*i+3] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))), vm);
+            vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(ql)), vm);
+#endif
         }
    }
 
-    inline void load(int l1, int i, __m512& v1, __m512& v2) const {
-        auto dl = (const block_q4_1 *)Base::lblock(l1) + i/2;
-        auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
-        auto vm = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->m));
+    inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+        int j = F16::block_size*i;
+        auto dl = (const block_q4_1 *)Base::lblock(l1) + j/QK4_1;
+        auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+        auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m));
         auto q = _mm_loadu_si128((const __m128i *)dl->qs);
+#ifdef HAVE_FANCY_SIMD
         auto ql = _mm_and_si128(q, mask);
         auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
         v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
         v2 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
+#else
+        if (j%QK4_1) q = _mm_srli_epi16(q, 4);
+        auto q16 = _mm256_cvtepi8_epi16(_mm_and_si128(q, mask));
+        v1 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))), vm);
+        v2 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))), vm);
+#endif
     }
 
-    inline void load_2(int l1, __m512 * vk) const {
+    inline void load_2(int l1, F16::Data * vk) const {
         load(l1+0, vk+0);
-        load(l1+1, vk+D/16);
+        load(l1+1, vk+D/F16::block_size);
     }
 
     const __m128i mask = _mm_set1_epi8(0xf);
@@ -7518,7 +7604,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
            HelperF16<D, k_step> vh(v, stride_v);
            iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
        } break;
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX2__
        case GGML_TYPE_Q8_0: {
            HelperQ80<D, k_step> vh(v, stride_v);
            iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -7547,7 +7633,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
            HelperF16<D, k_step> kh(k, stride_k);
            iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
        } break;
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX2__
        case GGML_TYPE_Q8_0: {
            HelperQ80<D, k_step> kh(k, stride_k);
            iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -7567,15 +7653,15 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
 }
 
 inline bool flash_attn_is_supported(ggml_type type) {
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX2__
+    if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true;
 #ifdef __AVX512BF16__
-    return type == GGML_TYPE_F16 || type == GGML_TYPE_BF16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1;
-#else
-    return type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1;
+    if (type == GGML_TYPE_BF16) return true;
 #endif
 #else
-    if (type == GGML_TYPE_F16) return true;
+    if (type == GGML_TYPE_F16) return true;
 #endif
+    return false;
 }
 
 }
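
What the change does: the quantized K/V-cache helpers for flash attention (HelperQ80, HelperQ40, HelperQ41) and the corresponding cases in iqk_flash_helper_T and flash_attn_is_supported were previously compiled only under HAVE_FANCY_SIMD (AVX-512). Each 512-bit dequantization sequence now gains an #else branch built from 256-bit AVX2 intrinsics, so Q8_0, Q4_0 and Q4_1 caches work on any AVX2 machine. The heart of the new Q8_0 branch is _mm_loadl_epi64 followed by _mm256_cvtepi8_epi32: 8 quants per register instead of 16. The standalone sketch below illustrates that instruction sequence for a single 32-quant block; block_q8_0_sketch is a simplified stand-in for ggml's block_q8_0 (ggml stores the scale as fp16 and, in this code path, groups four blocks into a block_q8_0_x4), so read it as an illustration of the instruction pattern, not as library code. Compile with -mavx2.

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Simplified stand-in for ggml's block_q8_0: a block scale (fp32 here,
// fp16 in ggml) followed by 32 int8 quants.
struct block_q8_0_sketch {
    float  d;
    int8_t qs[32];
};

// Dequantize one block into four __m256 registers (8 floats each),
// mirroring the structure of the new AVX2 (#else) branches above.
static void dequant_q8_0_avx2(const block_q8_0_sketch& b, __m256 out[4]) {
    __m256 vd = _mm256_set1_ps(b.d);                                  // broadcast the block scale
    for (int k = 0; k < 4; ++k) {
        __m128i q8  = _mm_loadl_epi64((const __m128i *)(b.qs + 8*k)); // load 8 int8 quants
        __m256i q32 = _mm256_cvtepi8_epi32(q8);                       // sign-extend to int32
        out[k] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(q32));          // convert and scale
    }
}

int main() {
    block_q8_0_sketch b{0.5f, {}};
    for (int i = 0; i < 32; ++i) b.qs[i] = (int8_t)(i - 16);
    __m256 v[4];
    dequant_q8_0_avx2(b, v);
    float res[32];
    for (int k = 0; k < 4; ++k) _mm256_storeu_ps(res + 8*k, v[k]);
    for (int i = 0; i < 32; ++i) printf("%g ", res[i]);   // prints -8 -7.5 ... 7.5
    printf("\n");
}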
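The Q4_0 and Q4_1 branches add one more wrinkle: 32 quants are packed as nibbles into 16 bytes, so the AVX2 path first splits low and high nibbles with a mask and a 16-bit shift (the shift smears bits across byte boundaries, which the subsequent mask removes), re-centers Q4_0 by -8, then widens int8 -> int16 -> int32 -> float. Here is a self-contained sketch under the same simplifying assumptions (fp32 scale, single block), mirroring the #else branch of HelperQ40::load():

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Simplified stand-in for ggml's block_q4_0: quant i sits in the low nibble
// of byte i, quant i+16 in the high nibble.
struct block_q4_0_sketch {
    float   d;
    uint8_t qs[16];
};

static void dequant_q4_0_avx2(const block_q4_0_sketch& b, __m256 out[4]) {
    const __m128i mask = _mm_set1_epi8(0xf);
    const __m128i m8   = _mm_set1_epi8(-8);              // recenter 0..15 -> -8..7
    __m256  vd = _mm256_set1_ps(b.d);
    __m128i q  = _mm_loadu_si128((const __m128i *)b.qs);
    __m128i ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);                    // quants 0..15
    __m128i qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8); // quants 16..31
    __m256i ql16 = _mm256_cvtepi8_epi16(ql);             // widen int8 -> int16
    __m256i qh16 = _mm256_cvtepi8_epi16(qh);
    out[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))));
    out[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))));
    out[2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))));
    out[3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))));
}

int main() {
    block_q4_0_sketch b{0.25f, {}};
    for (int i = 0; i < 16; ++i) b.qs[i] = (uint8_t)(i | (i << 4)); // both nibbles = i
    __m256 v[4];
    dequant_q4_0_avx2(b, v);
    float res[32];
    for (int k = 0; k < 4; ++k) _mm256_storeu_ps(res + 8*k, v[k]);
    for (int i = 0; i < 32; ++i) printf("%g ", res[i]);  // (i%16 - 8)*0.25, twice over
    printf("\n");
}

Q4_1 is the same dance except that the offset is carried as a separate fp16 minimum, so the recentering by -8 disappears and the multiply becomes an _mm256_fmadd_ps with the broadcast minimum as the addend.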
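Finally, the index arithmetic documented in the comment added to HelperQ80's two-vector load() can be checked mechanically. Each call fills v1 and v2 with 16 consecutive values, so i advances in steps of 2; on the AVX2 path F16::block_size is 8 (one __m256 holds 8 floats), while under HAVE_FANCY_SIMD it is 16, in which case the block_q8_0_x4 indexing reduces to the old i/8 and (i/2)%4. A small check of the mapping, with the AVX2 value of block_size assumed:

#include <cstdio>

// Constants mirroring the AVX2 path: 8 floats per __m256, 32 quants per
// Q8_0 block, 4 blocks per block_q8_0_x4 super-block, head size D = 256.
enum { kBlockSize = 8, kQK8_0 = 32, kD = 256 };

int main() {
    // i advances by 2: each call fills two 8-float registers = 16 values.
    for (int i = 0; i < kD/kBlockSize; i += 2) {
        int j      = kBlockSize*i;        // scalar offset of the first value
        int super  = j/(4*kQK8_0);        // which block_q8_0_x4 (128 values each)
        int ii     = (j/kQK8_0)%4;        // which 32-quant sub-block inside it
        int offset = 32*ii + j%32;        // offset of the 16 quants loaded
        printf("i=%2d -> j=%3d super=%d ii=%d qs offset=%2d\n", i, j, super, ii, offset);
    }
    return 0;
}

The output reproduces the comment's claim: the first block of 128 values serves i = 0, 2, ..., 14 and the second serves i = 16, 18, ..., 30, with ii cycling 0, 0, 1, 1, 2, 2, 3, 3 within each.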