From 6b4167164cdde5dd21b3786bebc0688f5023f326 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 24 Jul 2024 08:02:56 +0200 Subject: iqk_mul_mat(NEON): special case for n not divisible by 8 Else fp16 PP performance drops by nearly a factor of 2 compared to what we had before. --- iqk_mul_mat.cpp | 61 ++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 22 deletions(-) (limited to 'iqk_mul_mat.cpp') diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index c83d2d84..b29e182b 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -4235,11 +4235,10 @@ template struct QF16 final : public QF16Base { const __fp16 * y[nrc_y]; }; -template +template IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { assert(n%QF16Base::k_step == 0); int nb = n/QF16Base::k_step; - int nb4 = n/4; QF16 y(info); QF16 x(cx + ix0*bx, bx); QF16Base::Data xv[nrc_x]; @@ -4264,15 +4263,18 @@ IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0 for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]); } } - for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) { - yv = y.load_tail(0, i); - for (int ix = 0; ix < nrc_x; ++ix) { - xv[ix] = x.load_tail(ix, i); - acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]); - } - for (int iy = 1; iy < nrc_y; ++iy) { - yv = y.load_tail(iy, i); - for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]); + if constexpr (!is_multiple_of_k_step) { + int nb4 = n/4; + for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) { + yv = y.load_tail(0, i); + for (int ix = 0; ix < nrc_x; ++ix) { + xv[ix] = x.load_tail(ix, i); + acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]); + } + for (int iy = 1; iy < nrc_y; ++iy) { + yv = y.load_tail(iy, i); + for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]); + } } } for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix])); @@ -4283,17 +4285,32 @@ void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, GGML_ASSERT(n%4 == 0); constexpr int k_nx = 5; const char * cx = (const char *)vx; - for (int ix = 0; ix < nrc_x/k_nx; ++ix) { - mul_mat_f16_f16_NxN(n, cx, bx, ix*k_nx, info); - } - int last_x = k_nx*(nrc_x/k_nx); - if (last_x == nrc_x) return; - int nx = nrc_x - last_x; - switch (nx) { - case 1: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; - case 2: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; - case 3: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; - case 4: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + if (n%QF16Base::k_step == 0) { + for (int ix = 0; ix < nrc_x/k_nx; ++ix) { + mul_mat_f16_f16_NxN(n, cx, bx, ix*k_nx, info); + } + int last_x = k_nx*(nrc_x/k_nx); + if (last_x == nrc_x) return; + int nx = nrc_x - last_x; + switch (nx) { + case 1: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 2: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 3: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 4: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + } + } else { + for (int ix = 0; ix < nrc_x/k_nx; ++ix) { + mul_mat_f16_f16_NxN(n, cx, bx, ix*k_nx, info); + } + int last_x = k_nx*(nrc_x/k_nx); + if (last_x == nrc_x) return; + int nx = nrc_x - last_x; + switch (nx) { + case 1: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 2: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 3: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 4: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + } } } -- cgit v1.2.3