Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp  98
1 file changed, 97 insertions(+), 1 deletion(-)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 75e5c3c1..d1af9fe8 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -96,6 +96,11 @@ struct DataInfo {
_mm256_storeu_ps(dst_row(iy) + ix, result);
}
#endif
+#ifdef __AVX512F__
+ inline void store(int ix, int iy, __m512 result) const {
+ _mm512_storeu_ps(dst_row(iy) + ix, result);
+ }
+#endif
#ifdef __ARM_NEON
inline void store(int ix, int iy, float32x4_t result) const {
vst1q_f32(dst_row(iy) + ix, result);
@@ -179,6 +184,7 @@ struct MulMat {
case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_IQ2_BN_R4: return 4;
case GGML_TYPE_Q8_K_R8: return 8;
+ case GGML_TYPE_BF16_R16: return 16;
default: return 1;
}
}
@@ -3876,6 +3882,72 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
}
}
+#ifdef __AVX512BF16__
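+// Matrix multiplication kernel for bf16 weights in the row-interleaved R16
+// layout (16 rows packed together, per num_rows() above), using the
+// AVX512-BF16 dot product _mm512_dpbf16_ps to accumulate directly into f32.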
+template <int nrc_y>
+static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%16 == 0);
+ const ggml_bf16_t * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
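+    // Main loop: two packed 16-row groups (32 rows of the left matrix, i.e.
+    // 32 consecutive outputs per result row) per iteration.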
+ for (int ix = 0; ix < nrc_x/32; ++ix) {
+ __m512 acc[2*nrc_y] = {};
+ __m512bh qx[8];
+ const ggml_bf16_t * b8_1 = (const ggml_bf16_t *)((const char *)vx + (32*ix+ 0)*bx);
+ const ggml_bf16_t * b8_2 = (const ggml_bf16_t *)((const char *)vx + (32*ix+16)*bx);
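+        // Each 64-byte load holds 2 bf16 columns of 16 interleaved rows, so the
+        // 4 loads per group advance 8 columns, matching the 8 y values per ib.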
+ for (int ib = 0; ib < n/8; ++ib) {
+ qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+0);
+ qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+1);
+ qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+2);
+ qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+3);
+ qx[4] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+0);
+ qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1);
+ qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2);
+ qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
+ //auto y = _mm512_broadcast_i32x4(y128);
+ auto y256 = MM256_SET_M128I(y128, y128);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
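+                // y holds the same 8 activations in each 128-bit lane; each
+                // shuffle below broadcasts one dword (a pair of bf16 values) to
+                // all 16 lanes for the dot product against 16 interleaved rows.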
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[4], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(32*ix+ 0, iy, acc[2*iy+0]);
+ info.store(32*ix+16, iy, acc[2*iy+1]);
+ }
+ }
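+    // Leftover: a single packed 16-row group when nrc_x is a multiple of 16
+    // but not of 32; same scheme as above with one group instead of two.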
+ for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) {
+ __m512 acc[nrc_y] = {};
+ __m512bh qx[4];
+ const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char *)vx + (ix+0)*bx);
+ for (int ib = 0; ib < n/8; ++ib) {
+ qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0);
+ qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
+ qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
+ qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
+ auto y256 = MM256_SET_M128I(y128, y128);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ }
+ }
+}
+#endif
+
template <int nrc_y>
static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -5512,7 +5584,8 @@ struct QFBaseBF16 {
using Data = __m512bh;
using Acc = __m512;
static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); }
- static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ //static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ static inline Acc acc(Acc prev, Data y, Data x) {
return _mm512_dpbf16_ps(prev, y, x);
}
static inline Acc acc_first(const Data& y, const Data& x) {
@@ -5563,6 +5636,7 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
}
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16::hsum(acc[nrc_x*iy+ix]));
}
+
template <int nrc_y>
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
constexpr int k_nx = nrc_y <= 2 ? 8 : 5;
@@ -5777,6 +5851,17 @@ void set_mul_mat_bf16(MulMat& mm) {
mm.funcs[3] = mul_mat_fX_fY_T<4>;
mm.funcs[4] = mul_mat_fX_fY_T<5>;
}
+void set_mul_mat_bf16_r16(MulMat& mm) {
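+    // funcs[i] handles i+1 rows of the right matrix; the bf16_r16 kernel is
+    // instantiated for 1..8 rows.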
+ for (auto& f : mm.funcs) f = nullptr;
+ mm.funcs[0] = mul_mat_bf16_r16_bf16<1>;
+ mm.funcs[1] = mul_mat_bf16_r16_bf16<2>;
+ mm.funcs[2] = mul_mat_bf16_r16_bf16<3>;
+ mm.funcs[3] = mul_mat_bf16_r16_bf16<4>;
+ mm.funcs[4] = mul_mat_bf16_r16_bf16<5>;
+ mm.funcs[5] = mul_mat_bf16_r16_bf16<6>;
+ mm.funcs[6] = mul_mat_bf16_r16_bf16<7>;
+ mm.funcs[7] = mul_mat_bf16_r16_bf16<8>;
+}
#endif
bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
@@ -5794,6 +5879,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
return true;
}
+ if (typeA == GGML_TYPE_BF16_R16) {
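+        // BF16_R16 requires bf16 activations and AVX512-BF16 support.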
+ if (ne00 % 16) return false;
+ switch (typeB) {
+#ifdef __AVX512BF16__
+ case GGML_TYPE_BF16: set_mul_mat_bf16_r16(mm); break;
+#endif
+ default: return false;
+ }
+ return true;
+ }
+
if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) {
if (ne00 % 4) return false;
}