author     Kawrakow <iwankawrakow@gmail.com>      2025-01-22 12:13:55 +0200
committer  GitHub <noreply@github.com>            2025-01-22 12:13:55 +0200
commit     dbf5d31d01e14a0ba692efafca5e4d66ada60b8a (patch)
tree       64c7022a940a48c5f3153429758a9e1083f1edda
parent     6d23495b9bb8945c6ec1c38ced4b44180fbac3c6 (diff)
Better BF16 support on AVX2 (#175)
* Adding BF16 support for AVX2
PP performance is the same as fp16 (~153 t/s on a Ryzen-5975WX),
but TG is quite a bit lower (3.65 t/s vs 4.72 t/s at 8 threads).
Why? (The bf16-to-f32 load trick this path relies on is sketched after this message.)
* Slightly faster fp16/bf16 gemv on AVX2
It still saturates at the same lower performance for bf16.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
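For context: bf16 is simply the top 16 bits of an IEEE-754 f32, so widening bf16 to f32 amounts to a 16-bit left shift into a zeroed low half. That is exactly what the new load/load128 overloads in the diff below vectorize with _mm256_cvtepu16_epi32 followed by _mm256_slli_epi32(..., 16). A minimal scalar sketch of the same bit trick (illustration only; bf16_to_f32 is a hypothetical helper, not code from this patch):

#include <cstdint>
#include <cstring>
#include <cstdio>

// bf16 keeps only the high 16 bits of an f32, so widening is a shift:
// put the bf16 payload in the top half of a 32-bit word, zero the rest,
// and reinterpret the bits as float. No rounding or table lookup needed.
static inline float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;  // low 16 bits become zero mantissa bits
    float f;
    std::memcpy(&f, &bits, sizeof f);   // bit reinterpretation, not a value cast
    return f;
}

int main() {
    // 0x3F80 is the bf16 encoding of 1.0f (sign 0, exponent 127, mantissa 0)
    std::printf("%f\n", bf16_to_f32(0x3F80));  // prints 1.000000
    return 0;
}

The AVX2 code in the patch does the same thing for 8 values at once: a widening move plus one shift per load, versus a plain load for f32.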
-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp   27
1 file changed, 24 insertions, 3 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 109ac08e..7ddaee2a 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6970,6 +6970,9 @@ struct QFBase {
     using Acc = __m256;
     static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }
     static inline Data load(const float * x) { return _mm256_loadu_ps(x); }
+    static inline Data load(const ggml_bf16_t * x) {
+        return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16));
+    }
     static inline Acc acc(Acc prev, const Data& y, const Data& x) {
         return _mm256_fmadd_ps(y, x, prev);
     }
@@ -7003,6 +7006,9 @@ struct QFBase {
 #endif
     static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); }
     static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); }
+    static inline __m128 load128(const ggml_bf16_t * x) {
+        return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16));
+    }
 };
 template <typename Float, int nrc_in> struct QFT final : public QFBase {
     constexpr static int nrc = nrc_in;
@@ -7142,7 +7148,7 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
 #ifdef __AVX512F__
     constexpr int k_nx = 5;
 #else
-    constexpr int k_nx = 2;
+    constexpr int k_nx = nrc_y == 1 ? 4 : 2;
 #endif
     const char * cx = (const char *)vx;
     for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
@@ -7151,14 +7157,26 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
     int last_x = k_nx*(nrc_x/k_nx);
     if (last_x == nrc_x) return;
     int nx = nrc_x - last_x;
+#ifdef __AVX512F__
     switch (nx) {
         case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
-#ifdef __AVX512F__
         case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
         case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
         case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
-#endif
     }
+#else
+    if constexpr (nrc_y == 1) {
+        switch (nx) {
+            case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
+            case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
+            case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
+        }
+    } else {
+        switch (nx) {
+            case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
+        }
+    }
+#endif
 }

 #ifdef __AVX512BF16__
@@ -7456,6 +7474,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
             switch (typeB) {
 #ifdef __AVX512BF16__
                 case GGML_TYPE_BF16: set_mul_mat_bf16(mm); break;
+#else
+                case GGML_TYPE_BF16: set_mul_mat_f<ggml_bf16_t, ggml_bf16_t>(mm); break;
+                case GGML_TYPE_F32: set_mul_mat_f<ggml_bf16_t, float>(mm); break;
 #endif
                 default: return false;
             }
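A note on the tail handling above: mul_mat_fX_fY_T walks the rows of x in blocks of k_nx, and the switch on nx dispatches the 0 < nx < k_nx leftover rows. Since the patch raises k_nx to 4 on AVX2 when nrc_y == 1 (the single-row, token-generation case), the AVX2 remainder now needs cases 1-3 there, while the nrc_y > 1 path keeps k_nx = 2 and only needs case 1. A minimal standalone sketch of that blocking arithmetic (process_block is a hypothetical stand-in for mul_mat_Qx_Qy_MxN, not code from the patch):

#include <cstdio>

// Hypothetical stand-in for mul_mat_Qx_Qy_MxN: handles nrows rows of x
// starting at first_row against all rows of y.
static void process_block(int first_row, int nrows) {
    std::printf("rows [%d, %d)\n", first_row, first_row + nrows);
}

template <int k_nx>  // number of x rows processed together, as in the patch
static void walk_rows(int nrc_x) {
    for (int ix = 0; ix < nrc_x / k_nx; ++ix)
        process_block(ix * k_nx, k_nx);     // full blocks of k_nx rows
    int last_x = k_nx * (nrc_x / k_nx);     // first row not covered by full blocks
    if (last_x == nrc_x) return;            // nrc_x divisible by k_nx: no tail
    int nx = nrc_x - last_x;                // 1 .. k_nx-1 leftover rows
    process_block(last_x, nx);              // the patch dispatches this via switch (nx)
}

int main() {
    walk_rows<4>(10);  // two full blocks of 4 rows, then a tail of 2
    return 0;
}

The last hunk wires this fp16/f32-style path up for bf16 when __AVX512BF16__ is unavailable, so plain AVX2 builds now match a kernel for bf16 x bf16 and bf16 x f32 instead of falling through to default: return false.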