author    Kawrakow <iwankawrakow@gmail.com>    2024-09-16 16:47:36 +0300
committer GitHub <noreply@github.com>          2024-09-16 16:47:36 +0300
commit    2874b984006c6c8d0691ce000dcd9ca2cf9ff6fd (patch)
tree      4244cf6b022a6eb728f5d0eb3ba94a739681e345
parent    20f3e6fd2de6378d2a598b48edce369642bf2ee8 (diff)
iqk_mul_mat(ARM_NEON): adding bf16 support (#41)
It looks like the ARMv8 ISA has optional bf16 support, but my M2 Max does not have it, so I resort to bf16 -> f32 conversion and do the computations in f32. This is 2x slower than f16, but 8x faster than what I get if I try to run a bf16 model on the M2 (NEON and Metal).

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp  145
1 file changed, 144 insertions(+), 1 deletion(-)
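
For context before the diff: a bf16 value is the upper 16 bits of an IEEE-754 f32, so widening bf16 -> f32 is exact and amounts to a 16-bit left shift into the high half of a 32-bit lane. A minimal standalone sketch of that conversion (the helper names are mine, not from the patch; the NEON body mirrors QBF16Base::load from the diff below):

#include <arm_neon.h>
#include <cstdint>
#include <cstring>

// Scalar reference: put the bf16 bits into the high half of a u32 and
// reinterpret as f32. Truncating f32 -> bf16 loses precision, but the
// widening direction used here is exact.
static inline float bf16_to_f32(uint16_t h) {
    uint32_t u = (uint32_t)h << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// NEON version: load 4 x u16, widen to 4 x u32, shift left by 16,
// reinterpret the lanes as f32.
static inline float32x4_t bf16x4_to_f32x4(const uint16_t * x) {
    return vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16(x)), 16));
}

CPUs that implement the optional ARMv8 BF16 extension could instead use the native bfdot/bfmmla instructions; that is the ISA support the commit message alludes to, which the M2 lacks.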
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 82a55af4..7543d895 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -5774,6 +5774,7 @@ struct QF16Base {
}
};
template <int nrc> struct QF16 final : public QF16Base {
+ using Base = QF16Base;
constexpr static int nrc_y = nrc;
QF16(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)info.src1_row(iy);
@@ -5787,6 +5788,103 @@ template <int nrc> struct QF16 final : public QF16Base {
const __fp16 * y[nrc_y];
};
+struct QBF16Base {
+ constexpr static int k_step = 4;
+ using Data = float32x4_t;
+ using Acc = float32x4_t;
+ static inline Data load(const uint16_t * x) { return vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16(x)), 16)); }
+ static inline Data load4(const uint16_t * x) { return load(x); }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ return vfmaq_f32(prev, y, x);
+ }
+ static inline Acc acc_first(const Data& y, const Data& x) {
+ return vmulq_f32(y, x);
+ }
+ static inline float hsum(Acc acc) { return vaddvq_f32(acc); }
+};
+template <int nrc> struct QBF16 final : public QBF16Base {
+ using Base = QBF16Base;
+ constexpr static int nrc_y = nrc;
+ QBF16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const uint16_t *)info.src1_row(iy);
+ }
+ QBF16(const char * cx, size_t bx) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const uint16_t *)(cx + iy*bx);
+ }
+ IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
+ IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load(y[iy] + 4*i); }
+ const uint16_t * y[nrc_y];
+};
+
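+// f32 counterpart with the same interface; used to load the f32
+// right-hand side (activations) when pairing with QBF16.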
+struct QF32Base {
+ constexpr static int k_step = 4;
+ using Data = float32x4_t;
+ using Acc = float32x4_t;
+ static inline Data load(const float * x) { return vld1q_f32(x); }
+ static inline Data load4(const float * x) { return load(x); }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) { return vfmaq_f32(prev, y, x); }
+ static inline Acc acc_first(const Data& y, const Data& x) { return vmulq_f32(y, x); }
+ static inline float hsum(Acc acc) { return vaddvq_f32(acc); }
+};
+template <int nrc> struct QF32 final : public QF32Base {
+ using Base = QF32Base;
+ constexpr static int nrc_y = nrc;
+ QF32(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+ }
+ QF32(const char * cx, size_t bx) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)(cx + iy*bx);
+ }
+ IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
+ IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load(y[iy] + 4*i); }
+ const float * y[nrc_y];
+};
+
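+// Computes a Qx::nrc_y x Qy::nrc_y tile of dot products, keeping one f32x4
+// accumulator per (row, column) pair. When k_step > 4 and n is not a
+// multiple of k_step, a 4-wide tail loop picks up the remainder.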
+template <typename Qy, typename Qx, bool is_multiple_of_k_step>
+IQK_NOINLINE void mul_mat_Qx_Qy_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ GGML_ASSERT(Qx::Base::k_step == Qy::Base::k_step);
+ int nb = n/Qx::Base::k_step;
+ Qy y(info);
+ Qx x(cx + ix0*bx, bx);
+ typename Qx::Base::Data xv[Qx::nrc_y];
+ typename Qx::Base::Acc acc[Qx::nrc_y*Qy::nrc_y];
+ auto yv = y.load1(0, 0);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) {
+ xv[ix] = x.load1(ix, 0);
+ acc[ix] = Qx::Base::acc_first(yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc_y; ++iy) {
+ yv = y.load1(iy, 0);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc_first(yv, xv[ix]);
+ }
+ for (int i = 1; i < nb; ++i) {
+ yv = y.load1(0, i);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) {
+ xv[ix] = x.load1(ix, i);
+ acc[ix] = Qx::Base::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc_y; ++iy) {
+ yv = y.load1(iy, i);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc(acc[Qx::nrc_y*iy + ix], yv, xv[ix]);
+ }
+ }
+ if constexpr (Qx::Base::k_step > 4 && !is_multiple_of_k_step) {
+ int nb4 = n/4;
+ for (int i = (Qx::Base::k_step/4)*nb; i < nb4; ++i) {
+ yv = y.load_tail(0, i);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) {
+ xv[ix] = x.load_tail(ix, i);
+ acc[ix] = Qx::Base::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc_y; ++iy) {
+ yv = y.load_tail(iy, i);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc(acc[Qx::nrc_y*iy + ix], yv, xv[ix]);
+ }
+ }
+ }
+ for (int iy = 0; iy < Qy::nrc_y; ++iy) for (int ix = 0; ix < Qx::nrc_y; ++ix) info.store(ix0+ix, iy, Qx::Base::hsum(acc[Qx::nrc_y*iy+ix]));
+}
+
template <int nrc_y, int nrc_x, bool is_multiple_of_k_step>
IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
assert(n%QF16Base::k_step == 0);
@@ -5832,6 +5930,40 @@ IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix]));
}
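+// Walks the x rows in strips of k_nx = 5 and dispatches the leftover 1-4
+// rows through the switch below; the compile-time flag tells
+// mul_mat_Qx_Qy_NxN whether it needs the 4-wide tail loop.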
+template <typename Qy, template<int> typename Qx>
+void mul_mat_Qx_Qy_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%4 == 0);
+ constexpr int k_nx = 5;
+ const char * cx = (const char *)vx;
+ if (n%Qx<k_nx>::Base::k_step == 0) {
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_Qx_Qy_NxN<Qy, Qx<k_nx>, true>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ switch (nx) {
+ case 1: mul_mat_Qx_Qy_NxN<Qy, Qx<1>, true>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_Qx_Qy_NxN<Qy, Qx<2>, true>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_Qx_Qy_NxN<Qy, Qx<3>, true>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_Qx_Qy_NxN<Qy, Qx<4>, true>(n, cx, bx, last_x, info); break;
+ }
+ } else {
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_Qx_Qy_NxN<Qy, Qx<k_nx>, false>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ switch (nx) {
+ case 1: mul_mat_Qx_Qy_NxN<Qy, Qx<1>, false>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_Qx_Qy_NxN<Qy, Qx<2>, false>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_Qx_Qy_NxN<Qy, Qx<3>, false>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_Qx_Qy_NxN<Qy, Qx<4>, false>(n, cx, bx, last_x, info); break;
+ }
+ }
+}
+
template <int nrc_y>
void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%4 == 0);
@@ -5913,7 +6045,7 @@ IQK_NOINLINE void mul_mat_f16_f16_Nx1(int n, const char * cx, size_t bx, int ix0
}
}
-// At least on my M2-Max the version below, which dows the multiplication row-by-row, is faster.
+// At least on my M2-Max the version below, which does the multiplication row-by-row, is faster.
// But let's keep this version commented out for now.
//void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
// GGML_ASSERT(n%4 == 0);
@@ -6231,6 +6363,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
return true;
}
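+    // bf16 weights with f32 activations: ne00 must be a multiple of 4
+    // (the NEON load width); m.funcs[i] handles i+1 columns of the
+    // right-hand side.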
+ if (typeA == GGML_TYPE_BF16 && typeB == GGML_TYPE_F32) {
+ if (ne00%4) return false;
+ for (auto& f : m.funcs) f = nullptr;
+ m.funcs[0] = mul_mat_Qx_Qy_T<QF32<1>, QBF16>;
+ m.funcs[1] = mul_mat_Qx_Qy_T<QF32<2>, QBF16>;
+ m.funcs[2] = mul_mat_Qx_Qy_T<QF32<3>, QBF16>;
+ m.funcs[3] = mul_mat_Qx_Qy_T<QF32<4>, QBF16>;
+ m.funcs[4] = mul_mat_Qx_Qy_T<QF32<5>, QBF16>;
+ return true;
+ }
+
auto expected_Btype = GGML_TYPE_Q8_K;
switch (typeA) {