diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-10 09:53:26 +0300 |
---|---|---|
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:50 +0300 |
commit | ae1e77c5dee9b513e0b710075ef6713ede821b3c (patch) | |
tree | fa8838c2e15bc5723503510d6e46f1da1473accf | |
parent | 9386b499181a1d89c39e3a8114ef3255e9d52e63 (diff) |
iqk_mul_mat: better fp16 for AVX2
Basically use what I did for Arm.
Improves PP performance to 141.7 t/s up from 136 t/s
on the Ryzen-7950X (32 vector registers, so we use 5x5 tiling).
This is now 10% faster than tinyBLAS.
There is a minor improvement also on the Ryzen-5975WX
(16 vector registers, so we use 4x3 tiling): we get
138 t/s up from 136 t/s. tinyBLAS is at 132 t/s.
-rw-r--r-- | iqk_mul_mat.cpp | 132 |
1 file changed, 77 insertions, 55 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 721439a6..2b805088 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -2154,71 +2154,92 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_ inline static int block_size() { return QK4_1; } }; -template <int nrc> struct QF32 { +struct QF32Base { + constexpr static int k_step = 8; + using Data = __m256; + using Acc = __m256; + static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); } + static inline Data load(const float * x) { return _mm256_loadu_ps(x); } + static inline Acc acc(Acc prev, const Data& y, const Data& x) { + return _mm256_fmadd_ps(y, x, prev); + } + static inline Acc acc_first(const Data& y, const Data& x) { + return _mm256_mul_ps(y, x); + } + static inline float hsum(Acc acc) { + return hsum_float_8(acc); + } +}; +template <int nrc> struct QF32y final : public QF32Base { constexpr static int nrc_y = nrc; - QF32(const DataInfo& info) { + QF32y(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy); } -#ifdef __AVX512F__ - IQK_ALWAYS_INLINE __m512 load64(int iy, int i) const { return _mm512_loadu_ps(y[iy] + 16*i); } -#endif - IQK_ALWAYS_INLINE __m256 load1(int iy, int i) const { return _mm256_loadu_ps(y[iy] + 8*i); } - - const float * y[nrc_y]; + IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } + const float * y[nrc_y]; }; - -template <int nrc_y> -void mul_mat_f16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n%8 == 0); - constexpr int k_nx = 4; - int nb = n/8; - QF32<nrc_y> qf32(info); - const __m128i * x[k_nx]; - __m256 acc[k_nx*nrc_y]; - __m256 xv[k_nx]; - for (int ix = 0; ix < nrc_x/k_nx; ++ix) { - int ix0 = k_nx*ix; - for (int kx = 0; kx < k_nx; ++kx) { - x[kx] = (const __m128i *)((const char *)vx + (ix0 + kx)*bx); - xv[kx] = _mm256_cvtph_ps(_mm_loadu_si128(x[kx]++)); - } - for (int iy = 0; iy 
< nrc_y; ++iy) { - auto yv = qf32.load1(iy, 0); - for (int kx = 0; kx < k_nx; ++kx) acc[k_nx*iy + kx] = _mm256_mul_ps(yv, xv[kx]); - } - for (int i = 1; i < nb; ++i) { - for (int kx = 0; kx < k_nx; ++kx) xv[kx] = _mm256_cvtph_ps(_mm_loadu_si128(x[kx]++)); - for (int iy = 0; iy < nrc_y; ++iy) { - auto yv = qf32.load1(iy, i); - for (int kx = 0; kx < k_nx; ++kx) acc[k_nx*iy + kx] = _mm256_fmadd_ps(yv, xv[kx], acc[k_nx*iy + kx]); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - for (int kx = 0; kx < k_nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[k_nx*iy+kx])); - } +template <int nrc> struct QF32x final : public QF32Base { + constexpr static int nrc_x = nrc; + QF32x(const char * cx, size_t bx) { + for (int ix = 0; ix < nrc_x; ++ix) x[ix] = (const ggml_half *)(cx + ix*bx); } - int last_x = k_nx*(nrc_x/k_nx); - if (last_x == nrc_x) return; + IQK_ALWAYS_INLINE Data load1(int ix, int i) const { return load(x[ix] + k_step*i); } + const ggml_half * x[nrc_x]; +}; - // handle remaining rows - int ix0 = last_x; int nx = nrc_x - last_x; - for (int kx = 0; kx < nx; ++kx) { - x[kx] = (const __m128i *)((const char *)vx + (ix0 + kx)*bx); - xv[kx] = _mm256_cvtph_ps(_mm_loadu_si128(x[kx]++)); +template <int nrc_y, int nrc_x> +IQK_NOINLINE void mul_mat_f16_f32_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { + assert(n%QF32Base::k_step == 0); + int nb = n/QF32Base::k_step; + QF32y<nrc_y> y(info); + QF32x<nrc_x> x(cx + ix0*bx, bx); + QF32Base::Data xv[nrc_x]; + QF32Base::Acc acc[nrc_x*nrc_y]; + auto yv = y.load1(0, 0); + for (int ix = 0; ix < nrc_x; ++ix) { + xv[ix] = x.load1(ix, 0); + acc[ix] = QF32Base::acc_first(yv, xv[ix]); } - for (int iy = 0; iy < nrc_y; ++iy) { - auto yv = qf32.load1(iy, 0); - for (int kx = 0; kx < nx; ++kx) acc[nx*iy + kx] = _mm256_mul_ps(yv, xv[kx]); + for (int iy = 1; iy < nrc_y; ++iy) { + yv = y.load1(iy, 0); + for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF32Base::acc_first(yv, xv[ix]); } for (int i = 1; i < nb; ++i) { - 
for (int kx = 0; kx < nx; ++kx) xv[kx] = _mm256_cvtph_ps(_mm_loadu_si128(x[kx]++)); - for (int iy = 0; iy < nrc_y; ++iy) { - auto yv = qf32.load1(iy, i); - for (int kx = 0; kx < nx; ++kx) acc[nx*iy + kx] = _mm256_fmadd_ps(yv, xv[kx], acc[nx*iy + kx]); + yv = y.load1(0, i); + for (int ix = 0; ix < nrc_x; ++ix) { + xv[ix] = x.load1(ix, i); + acc[ix] = QF32Base::acc(acc[ix], yv, xv[ix]); + } + for (int iy = 1; iy < nrc_y; ++iy) { + yv = y.load1(iy, i); + for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF32Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]); } } - for (int iy = 0; iy < nrc_y; ++iy) { - for (int kx = 0; kx < nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[nx*iy+kx])); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF32Base::hsum(acc[nrc_x*iy+ix])); +} + +template <int nrc_y> +void mul_mat_f16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QF32Base::k_step == 0); +#ifdef __AVX512F__ + constexpr int k_nx = 5; +#else + constexpr int k_nx = 3; +#endif + const char * cx = (const char *)vx; + for (int ix = 0; ix < nrc_x/k_nx; ++ix) { + mul_mat_f16_f32_NxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info); + } + int last_x = k_nx*(nrc_x/k_nx); + if (last_x == nrc_x) return; + int nx = nrc_x - last_x; + switch (nx) { + case 1: mul_mat_f16_f32_NxN<nrc_y, 1>(n, cx, bx, last_x, info); break; + case 2: mul_mat_f16_f32_NxN<nrc_y, 2>(n, cx, bx, last_x, info); break; +#ifdef __AVX512F__ + case 3: mul_mat_f16_f32_NxN<nrc_y, 3>(n, cx, bx, last_x, info); break; + case 4: mul_mat_f16_f32_NxN<nrc_y, 4>(n, cx, bx, last_x, info); break; +#endif } } @@ -2370,8 +2391,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int mm.funcs[0] = mul_mat_f16_f32_T<1>; mm.funcs[1] = mul_mat_f16_f32_T<2>; mm.funcs[2] = mul_mat_f16_f32_T<3>; -#ifdef __AVX512F__ mm.funcs[3] = mul_mat_f16_f32_T<4>; +#ifdef __AVX512F__ + mm.funcs[4] = mul_mat_f16_f32_T<5>; #endif row_size_q8 = 
ggml_row_size(GGML_TYPE_F32, ne00); return true; |