summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--iqk_mul_mat.cpp132
1 files changed, 77 insertions, 55 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 721439a6..2b805088 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -2154,71 +2154,92 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_
inline static int block_size() { return QK4_1; }
};
-template <int nrc> struct QF32 {
+struct QF32Base {
+ constexpr static int k_step = 8;
+ using Data = __m256;
+ using Acc = __m256;
+ static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }
+ static inline Data load(const float * x) { return _mm256_loadu_ps(x); }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ return _mm256_fmadd_ps(y, x, prev);
+ }
+ static inline Acc acc_first(const Data& y, const Data& x) {
+ return _mm256_mul_ps(y, x);
+ }
+ static inline float hsum(Acc acc) {
+ return hsum_float_8(acc);
+ }
+};
+template <int nrc> struct QF32y final : public QF32Base {
constexpr static int nrc_y = nrc;
- QF32(const DataInfo& info) {
+ QF32y(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
}
-#ifdef __AVX512F__
- IQK_ALWAYS_INLINE __m512 load64(int iy, int i) const { return _mm512_loadu_ps(y[iy] + 16*i); }
-#endif
- IQK_ALWAYS_INLINE __m256 load1(int iy, int i) const { return _mm256_loadu_ps(y[iy] + 8*i); }
-
- const float * y[nrc_y];
+ IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
+ const float * y[nrc_y];
};
-
-template <int nrc_y>
-void mul_mat_f16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n%8 == 0);
- constexpr int k_nx = 4;
- int nb = n/8;
- QF32<nrc_y> qf32(info);
- const __m128i * x[k_nx];
- __m256 acc[k_nx*nrc_y];
- __m256 xv[k_nx];
- for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
- int ix0 = k_nx*ix;
- for (int kx = 0; kx < k_nx; ++kx) {
- x[kx] = (const __m128i *)((const char *)vx + (ix0 + kx)*bx);
- xv[kx] = _mm256_cvtph_ps(_mm_loadu_si128(x[kx]++));
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto yv = qf32.load1(iy, 0);
- for (int kx = 0; kx < k_nx; ++kx) acc[k_nx*iy + kx] = _mm256_mul_ps(yv, xv[kx]);
- }
- for (int i = 1; i < nb; ++i) {
- for (int kx = 0; kx < k_nx; ++kx) xv[kx] = _mm256_cvtph_ps(_mm_loadu_si128(x[kx]++));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto yv = qf32.load1(iy, i);
- for (int kx = 0; kx < k_nx; ++kx) acc[k_nx*iy + kx] = _mm256_fmadd_ps(yv, xv[kx], acc[k_nx*iy + kx]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- for (int kx = 0; kx < k_nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[k_nx*iy+kx]));
- }
+template <int nrc> struct QF32x final : public QF32Base {
+ constexpr static int nrc_x = nrc;
+ QF32x(const char * cx, size_t bx) {
+ for (int ix = 0; ix < nrc_x; ++ix) x[ix] = (const ggml_half *)(cx + ix*bx);
}
- int last_x = k_nx*(nrc_x/k_nx);
- if (last_x == nrc_x) return;
+ IQK_ALWAYS_INLINE Data load1(int ix, int i) const { return load(x[ix] + k_step*i); }
+ const ggml_half * x[nrc_x];
+};
- // handle remaining rows
- int ix0 = last_x; int nx = nrc_x - last_x;
- for (int kx = 0; kx < nx; ++kx) {
- x[kx] = (const __m128i *)((const char *)vx + (ix0 + kx)*bx);
- xv[kx] = _mm256_cvtph_ps(_mm_loadu_si128(x[kx]++));
+template <int nrc_y, int nrc_x>
+IQK_NOINLINE void mul_mat_f16_f32_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ assert(n%QF16Base::k_step == 0);
+ int nb = n/QF32Base::k_step;
+ QF32y<nrc_y> y(info);
+ QF32x<nrc_x> x(cx + ix0*bx, bx);
+ QF32Base::Data xv[nrc_x];
+ QF32Base::Acc acc[nrc_x*nrc_y];
+ auto yv = y.load1(0, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ xv[ix] = x.load1(ix, 0);
+ acc[ix] = QF32Base::acc_first(yv, xv[ix]);
}
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto yv = qf32.load1(iy, 0);
- for (int kx = 0; kx < nx; ++kx) acc[nx*iy + kx] = _mm256_mul_ps(yv, xv[kx]);
+ for (int iy = 1; iy < nrc_y; ++iy) {
+ yv = y.load1(iy, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF32Base::acc_first(yv, xv[ix]);
}
for (int i = 1; i < nb; ++i) {
- for (int kx = 0; kx < nx; ++kx) xv[kx] = _mm256_cvtph_ps(_mm_loadu_si128(x[kx]++));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto yv = qf32.load1(iy, i);
- for (int kx = 0; kx < nx; ++kx) acc[nx*iy + kx] = _mm256_fmadd_ps(yv, xv[kx], acc[nx*iy + kx]);
+ yv = y.load1(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ xv[ix] = x.load1(ix, i);
+ acc[ix] = QF32Base::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < nrc_y; ++iy) {
+ yv = y.load1(iy, i);
+ for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF32Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
}
}
- for (int iy = 0; iy < nrc_y; ++iy) {
- for (int kx = 0; kx < nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[nx*iy+kx]));
+ for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF32Base::hsum(acc[nrc_x*iy+ix]));
+}
+
+template <int nrc_y>
+void mul_mat_f16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QF32Base::k_step == 0);
+#ifdef __AVX512F__
+ constexpr int k_nx = 5;
+#else
+ constexpr int k_nx = 3;
+#endif
+ const char * cx = (const char *)vx;
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_f16_f32_NxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ switch (nx) {
+ case 1: mul_mat_f16_f32_NxN<nrc_y, 1>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_f16_f32_NxN<nrc_y, 2>(n, cx, bx, last_x, info); break;
+#ifdef __AVX512F__
+ case 3: mul_mat_f16_f32_NxN<nrc_y, 3>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_f16_f32_NxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
+#endif
}
}
@@ -2370,8 +2391,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
mm.funcs[0] = mul_mat_f16_f32_T<1>;
mm.funcs[1] = mul_mat_f16_f32_T<2>;
mm.funcs[2] = mul_mat_f16_f32_T<3>;
-#ifdef __AVX512F__
mm.funcs[3] = mul_mat_f16_f32_T<4>;
+#ifdef __AVX512F__
+ mm.funcs[4] = mul_mat_f16_f32_T<5>;
#endif
row_size_q8 = ggml_row_size(GGML_TYPE_F32, ne00);
return true;