From 81cf6990f512e82c2c89ba7f89a15c3d98172f84 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 10 Jun 2024 16:43:42 +0300 Subject: iqk_mul_mat: be able to handle any f16/f32 combination on AVX2 But only turning on f16 x f32 and f32 x f16 for now. --- iqk_mul_mat.cpp | 108 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 60 insertions(+), 48 deletions(-) (limited to 'iqk_mul_mat.cpp') diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 8f0b9816..9934d2e6 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -2153,7 +2153,7 @@ struct Q5_1_Unpacker final : public Q_Unpacker struct QF32y final : public QF32Base { - constexpr static int nrc_y = nrc; - QF32y(const DataInfo& info) { - for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy); +template struct QFT final : public QFBase { + constexpr static int nrc = nrc_in; + QFT(const DataInfo& info) { + for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy); } - IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } - const float * y[nrc_y]; -}; -template struct QF32x final : public QF32Base { - constexpr static int nrc_x = nrc; - QF32x(const char * cx, size_t bx) { - for (int ix = 0; ix < nrc_x; ++ix) x[ix] = (const ggml_half *)(cx + ix*bx); + QFT(const char * cx, size_t bx) { + for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx); } - IQK_ALWAYS_INLINE Data load1(int ix, int i) const { return load(x[ix] + k_step*i); } - const ggml_half * x[nrc_x]; -}; - -template -IQK_NOINLINE void mul_mat_f16_f32_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { - assert(n%QF16Base::k_step == 0); - int nb = n/QF32Base::k_step; - QF32y y(info); - QF32x x(cx + ix0*bx, bx); - QF32Base::Data xv[nrc_x]; - QF32Base::Acc acc[nrc_x*nrc_y]; + IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } + const Float * y[nrc]; +}; +//template using QF32 = QFT; +//template using QF16 = QFT; + +template +IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { + assert(n%QFBase::k_step == 0); + int nb = n/QFBase::k_step; + Qy y(info); + Qx x(cx + ix0*bx, bx); + QFBase::Data xv[Qx::nrc]; + QFBase::Acc acc[Qx::nrc*Qy::nrc]; auto yv = y.load1(0, 0); - for (int ix = 0; ix < nrc_x; ++ix) { + for (int ix = 0; ix < Qx::nrc; ++ix) { xv[ix] = x.load1(ix, 0); - acc[ix] = QF32Base::acc_first(yv, xv[ix]); + acc[ix] = QFBase::acc_first(yv, xv[ix]); } - for (int iy = 1; iy < nrc_y; ++iy) { + for (int iy = 1; iy < Qy::nrc; ++iy) { yv = y.load1(iy, 0); - for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF32Base::acc_first(yv, xv[ix]); + for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]); } for (int i = 1; i < nb; ++i) { yv = y.load1(0, i); - for (int ix = 0; ix < nrc_x; ++ix) { + for (int ix = 0; ix < Qx::nrc; ++ix) { xv[ix] = x.load1(ix, i); - acc[ix] = QF32Base::acc(acc[ix], yv, xv[ix]); + acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]); } - for (int iy = 1; iy < nrc_y; ++iy) { + for (int iy = 1; iy < Qy::nrc; ++iy) { yv = y.load1(iy, i); - for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF32Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]); + for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]); } } - for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF32Base::hsum(acc[nrc_x*iy+ix])); + for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix])); } - -template -void mul_mat_f16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n%QF32Base::k_step == 0); +// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done +// in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in +// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now. +template +void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QFBase::k_step == 0); #ifdef __AVX512F__ constexpr int k_nx = 5; #else @@ -2244,17 +2243,17 @@ void mul_mat_f16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info, #endif const char * cx = (const char *)vx; for (int ix = 0; ix < nrc_x/k_nx; ++ix) { - mul_mat_f16_f32_MxN(n, cx, bx, ix*k_nx, info); + mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, ix*k_nx, info); } int last_x = k_nx*(nrc_x/k_nx); if (last_x == nrc_x) return; int nx = nrc_x - last_x; switch (nx) { - case 1: mul_mat_f16_f32_MxN(n, cx, bx, last_x, info); break; + case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; #ifdef __AVX512F__ - case 2: mul_mat_f16_f32_MxN(n, cx, bx, last_x, info); break; - case 3: mul_mat_f16_f32_MxN(n, cx, bx, last_x, info); break; - case 4: mul_mat_f16_f32_MxN(n, cx, bx, last_x, info); break; + case 2: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; + case 3: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; + case 4: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; #endif } } @@ -2404,17 +2403,30 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int if (typeA == GGML_TYPE_F16) { for (auto& f : mm.funcs) f = nullptr; - mm.funcs[0] = mul_mat_f16_f32_T<1>; - mm.funcs[1] = mul_mat_f16_f32_T<2>; - mm.funcs[2] = mul_mat_f16_f32_T<3>; - mm.funcs[3] = mul_mat_f16_f32_T<4>; - mm.funcs[4] = mul_mat_f16_f32_T<5>; + mm.funcs[0] = mul_mat_fX_fY_T<1, ggml_half, float>; + mm.funcs[1] = mul_mat_fX_fY_T<2, ggml_half, float>; + mm.funcs[2] = mul_mat_fX_fY_T<3, ggml_half, float>; + mm.funcs[3] = mul_mat_fX_fY_T<4, ggml_half, float>; + mm.funcs[4] = mul_mat_fX_fY_T<5, ggml_half, float>; #ifndef __AVX512F__ - mm.funcs[5] = mul_mat_f16_f32_T<6>; + mm.funcs[5] = mul_mat_fX_fY_T<6, ggml_half, float>; #endif row_size_q8 = ggml_row_size(GGML_TYPE_F32, ne00); return true; } + if (typeA == GGML_TYPE_F32) { + for (auto& f : mm.funcs) f = nullptr; + mm.funcs[0] = mul_mat_fX_fY_T<1, float, ggml_half>; + mm.funcs[1] = mul_mat_fX_fY_T<2, float, ggml_half>; + mm.funcs[2] = mul_mat_fX_fY_T<3, float, ggml_half>; + mm.funcs[3] = mul_mat_fX_fY_T<4, float, ggml_half>; + mm.funcs[4] = mul_mat_fX_fY_T<5, float, ggml_half>; +#ifndef __AVX512F__ + mm.funcs[5] = mul_mat_fX_fY_T<6, float, ggml_half>; +#endif + row_size_q8 = ggml_row_size(GGML_TYPE_F16, ne00); + return true; + } // Using the standard legacy quant template is slightly faster than tiling // as implemented in mul_mat_q80_q80_T // if (typeA == GGML_TYPE_Q8_0) { -- cgit v1.2.3