summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-07-25 08:37:13 +0200
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-07-25 08:37:13 +0200
commitc2158c15d9d3d916e564411f19afd8138dc8317c (patch)
treec08658bfce49833234717c7091e2a9fb88c73a4b
parent28fb349db49d090c9a430076dc454fa8c878c2ec (diff)
iqk_mul_mat(NEON): adding forgotten fp16 matrix x vector implementation
-rw-r--r--iqk_mul_mat.cpp97
1 files changed, 96 insertions, 1 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 1dee6ef3..3915d44c 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -4223,6 +4223,7 @@ template <int nrc> struct QF16 final : public QF16Base {
}
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + 4*i); }
+ IQK_ALWAYS_INLINE float16x8x4_t loadx(int iy, int i) const { return vld1q_f16_x4(y[iy] + 4*k_step*i); }
const __fp16 * y[nrc_y];
};
@@ -4305,6 +4306,100 @@ void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info,
}
}
+template <int nrc_x, bool is_multiple_of_k_step>
+IQK_NOINLINE void mul_mat_f16_f16_Nx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ assert(n%QF16Base::k_step == 0);
+ int nb = n/QF16Base::k_step;
+ QF16<1> y(info);
+ QF16<nrc_x> x(cx + ix0*bx, bx);
+ QF16Base::Acc acc[4*nrc_x];
+ auto yv = y.loadx(0, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ for (int k = 0; k < 4; ++k) {
+ auto xv = x.load1(ix, k);
+ acc[4*ix+k] = QF16Base::acc_first(yv.val[k], xv);
+ }
+ }
+ for (int i = 1; i < nb/4; ++i) {
+ yv = y.loadx(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ for (int k = 0; k < 4; ++k) {
+ auto xv = x.load1(ix, 4*i+k);
+ acc[4*ix+k] = QF16Base::acc(acc[4*ix+k], yv.val[k], xv);
+ }
+ }
+ }
+ for (int i = 4*(nb/4); i < nb; ++i) {
+ auto yv1 = y.load1(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto xv1 = x.load1(ix, i);
+ acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
+ }
+ }
+ if constexpr (!is_multiple_of_k_step) {
+ int nb4 = n/4;
+ for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
+ auto yv1 = y.load_tail(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto xv1 = x.load_tail(ix, i);
+ acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
+ }
+ }
+ }
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto v1 = vaddq_f16(acc[4*ix+0], acc[4*ix+1]);
+ auto v2 = vaddq_f16(acc[4*ix+2], acc[4*ix+3]);
+ info.store(ix0+ix, 0, QF16Base::hsum(vaddq_f16(v1, v2)));
+ }
+}
+
+// At least on my M2-Max the version below, which dows the multiplication row-by-row, is faster.
+// But let's keep this version commented out for now.
+//void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+// GGML_ASSERT(n%4 == 0);
+// constexpr int k_nx = 2;
+// const char * cx = (const char *)vx;
+// if (n%QF16Base::k_step == 0) {
+// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+// mul_mat_f16_f16_Nx1<k_nx, true>(n, cx, bx, ix*k_nx, info);
+// }
+// int last_x = k_nx*(nrc_x/k_nx);
+// if (last_x == nrc_x) return;
+// int nx = nrc_x - last_x;
+// switch (nx) {
+// case 1: mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, last_x, info); break;
+// //case 2: mul_mat_f16_f16_Nx1<2, true>(n, cx, bx, last_x, info); break;
+// //case 3: mul_mat_f16_f16_Nx1<3, true>(n, cx, bx, last_x, info); break;
+// }
+// } else {
+// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+// mul_mat_f16_f16_Nx1<k_nx, false>(n, cx, bx, ix*k_nx, info);
+// }
+// int last_x = k_nx*(nrc_x/k_nx);
+// if (last_x == nrc_x) return;
+// int nx = nrc_x - last_x;
+// switch (nx) {
+// case 1: mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, last_x, info); break;
+// //case 2: mul_mat_f16_f16_Nx1<2, false>(n, cx, bx, last_x, info); break;
+// //case 3: mul_mat_f16_f16_Nx1<3, false>(n, cx, bx, last_x, info); break;
+// }
+// }
+//}
+
+void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%4 == 0);
+ const char * cx = (const char *)vx;
+ if (n%QF16Base::k_step == 0) {
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, ix, info);
+ }
+ } else {
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, ix, info);
+ }
+ }
+}
+
template <int nrc> struct Q8_K64 {
constexpr static int nrc_y = nrc;
@@ -4549,7 +4644,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) {
if (ne00%4) return false;
for (auto& f : m.funcs) f = nullptr;
- m.funcs[0] = mul_mat_f16_f16_T<1>;
+ m.funcs[0] = mul_mat_f16_f16_1;
m.funcs[1] = mul_mat_f16_f16_T<2>;
m.funcs[2] = mul_mat_f16_f16_T<3>;
m.funcs[3] = mul_mat_f16_f16_T<4>;