diff options
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 98 |
1 files changed, 97 insertions, 1 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 75e5c3c1..d1af9fe8 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -96,6 +96,11 @@ struct DataInfo { _mm256_storeu_ps(dst_row(iy) + ix, result); } #endif +#ifdef __AVX512F__ + inline void store(int ix, int iy, __m512 result) const { + _mm512_storeu_ps(dst_row(iy) + ix, result); + } +#endif #ifdef __ARM_NEON inline void store(int ix, int iy, float32x4_t result) const { vst1q_f32(dst_row(iy) + ix, result); @@ -179,6 +184,7 @@ struct MulMat { case GGML_TYPE_IQ4_XS_R4: case GGML_TYPE_IQ2_BN_R4: return 4; case GGML_TYPE_Q8_K_R8: return 8; + case GGML_TYPE_BF16_R16: return 16; default: return 1; } } @@ -3876,6 +3882,72 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn } } +#ifdef __AVX512BF16__ +template <int nrc_y> +static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%16 == 0); + const ggml_bf16_t * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy); + for (int ix = 0; ix < nrc_x/32; ++ix) { + __m512 acc[2*nrc_y] = {}; + __m512bh qx[8]; + const ggml_bf16_t * b8_1 = (const ggml_bf16_t *)((const char *)vx + (32*ix+ 0)*bx); + const ggml_bf16_t * b8_2 = (const ggml_bf16_t *)((const char *)vx + (32*ix+16)*bx); + for (int ib = 0; ib < n/8; ++ib) { + qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+0); + qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+1); + qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+2); + qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+3); + qx[4] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+0); + qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1); + qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2); + qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib); + //auto y = _mm512_broadcast_i32x4(y128); + auto y256 = MM256_SET_M128I(y128, y128); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[4], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(32*ix+ 0, iy, acc[2*iy+0]); + info.store(32*ix+16, iy, acc[2*iy+1]); + } + } + for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) { + __m512 acc[nrc_y] = {}; + __m512bh qx[4]; + const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char *)vx + (ix+0)*bx); + for (int ib = 0; ib < n/8; ++ib) { + qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0); + qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1); + qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2); + qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib); + auto y256 = MM256_SET_M128I(y128, y128); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + } + } +} +#endif + template <int nrc_y> static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -5512,7 +5584,8 @@ struct QFBaseBF16 { using Data = __m512bh; using Acc = __m512; static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); } - static inline Acc acc(Acc prev, const Data& y, const Data& x) { + //static inline Acc acc(Acc prev, const Data& y, const Data& x) { + static inline Acc acc(Acc prev, Data y, Data x) { return _mm512_dpbf16_ps(prev, y, x); } static inline Acc acc_first(const Data& y, const Data& x) { @@ -5563,6 +5636,7 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, } for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16::hsum(acc[nrc_x*iy+ix])); } + template <int nrc_y> void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { constexpr int k_nx = nrc_y <= 2 ? 8 : 5; @@ -5777,6 +5851,17 @@ void set_mul_mat_bf16(MulMat& mm) { mm.funcs[3] = mul_mat_fX_fY_T<4>; mm.funcs[4] = mul_mat_fX_fY_T<5>; } +void set_mul_mat_bf16_r16(MulMat& mm) { + for (auto& f : mm.funcs) f = nullptr; + mm.funcs[0] = mul_mat_bf16_r16_bf16<1>; + mm.funcs[1] = mul_mat_bf16_r16_bf16<2>; + mm.funcs[2] = mul_mat_bf16_r16_bf16<3>; + mm.funcs[3] = mul_mat_bf16_r16_bf16<4>; + mm.funcs[4] = mul_mat_bf16_r16_bf16<5>; + mm.funcs[5] = mul_mat_bf16_r16_bf16<6>; + mm.funcs[6] = mul_mat_bf16_r16_bf16<7>; + mm.funcs[7] = mul_mat_bf16_r16_bf16<8>; +} #endif bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { @@ -5794,6 +5879,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { return true; } + if (typeA == GGML_TYPE_BF16_R16) { + if (ne00 % 16) return false; + switch (typeB) { +#ifdef __AVX512BF16__ + case GGML_TYPE_BF16: set_mul_mat_bf16_r16(mm); break; +#endif + default: return false; + } + return true; + } + if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) { if (ne00 % 4) return false; } |