diff options
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r-- | iqk_mul_mat.cpp | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index dc43d0fc..721439a6 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -3761,6 +3761,93 @@ static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInf mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x); } +struct QF16Base { + constexpr static int k_step = 8; + using Data = float16x8_t; + using Acc = float16x8_t; + static inline Data load(const __fp16 * x) { return vld1q_f16(x); } + static inline Acc acc(Acc prev, const Data& y, const Data& x) { + return vfmaq_f16(prev, y, x); + } + static inline Acc acc_first(const Data& y, const Data& x) { + return vmulq_f16(y, x); + } + //constexpr static int k_step = 16; + //using Data = float16x8x2_t; + //static inline Data load(const __fp16 * x) { return vld1q_f16_x2(x); } + //static inline Acc acc(Acc prev, const Data& y, const Data& x) { + // return vfmaq_f16(vfmaq_f16(prev, y.val[0], x.val[0]), y.val[1], x.val[1]); + //} + //static inline Acc acc_first(const Data& y, const Data& x) { + // return vfmaq_f16(vmulq_f16(y.val[0], x.val[0]), y.val[1], x.val[1]); + //} + static inline float hsum(Acc acc) { + float32x4_t sum = vcvt_f32_f16(vadd_f16(vget_low_f16(acc), vget_high_f16(acc))); + return vaddvq_f32(sum); + } +}; +template <int nrc> struct QF16 final : public QF16Base { + constexpr static int nrc_y = nrc; + QF16(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)info.src1_row(iy); + } + QF16(const char * cx, size_t bx) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)(cx + iy*bx); + } + IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } + const __fp16 * y[nrc_y]; +}; + +template <int nrc_y, int nrc_x> +IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { + assert(n%QF16Base::k_step == 0); + int nb = n/QF16Base::k_step; + QF16<nrc_y> y(info); + QF16<nrc_x> x(cx + ix0*bx, bx); + QF16Base::Data xv[nrc_x]; + QF16Base::Acc acc[nrc_x*nrc_y]; + auto yv = y.load1(0, 0); + for (int ix = 0; ix < nrc_x; ++ix) { + xv[ix] = x.load1(ix, 0); + acc[ix] = QF16Base::acc_first(yv, xv[ix]); + } + for (int iy = 1; iy < nrc_y; ++iy) { + yv = y.load1(iy, 0); + for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc_first(yv, xv[ix]); + } + for (int i = 1; i < nb; ++i) { + yv = y.load1(0, i); + for (int ix = 0; ix < nrc_x; ++ix) { + xv[ix] = x.load1(ix, i); + acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]); + } + for (int iy = 1; iy < nrc_y; ++iy) { + yv = y.load1(iy, i); + for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix])); +} + +template <int nrc_y> +void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QF16Base::k_step == 0); + constexpr int k_nx = 5; + const char * cx = (const char *)vx; + for (int ix = 0; ix < nrc_x/k_nx; ++ix) { + mul_mat_f16_f16_NxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info); + } + int last_x = k_nx*(nrc_x/k_nx); + if (last_x == nrc_x) return; + int nx = nrc_x - last_x; + switch (nx) { + case 1: mul_mat_f16_f16_NxN<nrc_y, 1>(n, cx, bx, last_x, info); break; + case 2: mul_mat_f16_f16_NxN<nrc_y, 2>(n, cx, bx, last_x, info); break; + case 3: mul_mat_f16_f16_NxN<nrc_y, 3>(n, cx, bx, last_x, info); break; + case 4: mul_mat_f16_f16_NxN<nrc_y, 4>(n, cx, bx, last_x, info); break; + } +} + template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> || std::is_same_v<Dequantizer, DequantizerQ80>) { @@ -3798,6 +3885,19 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /*Ny*/) { row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); + if (typeA == GGML_TYPE_F16) { + for (auto& f : m.funcs) f = nullptr; + m.funcs[0] = mul_mat_f16_f16_T<1>; + m.funcs[1] = mul_mat_f16_f16_T<2>; + m.funcs[2] = mul_mat_f16_f16_T<3>; + m.funcs[3] = mul_mat_f16_f16_T<4>; + m.funcs[4] = mul_mat_f16_f16_T<5>; + //m.funcs[5] = mul_mat_f16_f16_T<6>; + //m.funcs[6] = mul_mat_f16_f16_T<7>; + row_size_q8 = ggml_row_size(GGML_TYPE_F16, ne00); + return true; + } + switch (typeA) { case GGML_TYPE_Q2_K: MulMat::set_functions<DequantizerQ2K>(m); |