From 85c5a1a99569ccc00c280835fe3a69b4af02c43b Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 15 Dec 2024 09:54:21 +0100 Subject: BF16_R16 - 16 interleaved bf16 rows (#142) * Not working bf16_r4 * Adding bf16_r8 Small performance gain compared to bf16 - 258 t/s vs 234 t/s. I guess, this is still sub-obtimal. * bf16_rx: Very slightly faster by interleaving 16 rows 258 t/s -> 263 t/s * Rename bf16_r4 to bf16_r16 We are interleaving 16 rows now. * Cleanup unused stuff --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-quants.c | 1 + ggml/src/ggml.c | 20 +++++++++ ggml/src/iqk/iqk_mul_mat.cpp | 98 ++++++++++++++++++++++++++++++++++++++++++- ggml/src/iqk/iqk_quantize.cpp | 38 ++++++++++++++++- ggml/src/iqk/iqk_quantize.h | 3 ++ 5 files changed, 158 insertions(+), 2 deletions(-) (limited to 'ggml/src') diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index f12c9fe8..0b157295 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15209,6 +15209,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_Q6_K_R4: break; case GGML_TYPE_IQ4_K_R4: break; case GGML_TYPE_Q8_K_R8: break; + case GGML_TYPE_BF16_R16: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 772c70c4..51ef6eb2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1231,6 +1231,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_BF16_R16] = { + .type_name = "bf16_r16", + .blck_size = 1, + .type_size = sizeof(ggml_bf16_t), + .is_quantized = false, + //.to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, + //.from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, + //.from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref, + //.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, + .vec_dot_type = GGML_TYPE_BF16, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q4_0_4_4] = { .type_name = "q4_0_4x4", .blck_size = QK4_0, @@ -4110,6 +4123,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; + case GGML_FTYPE_MOSTLY_BF16_R16: wtype = GGML_TYPE_BF16_R16;break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; @@ -15748,6 +15762,7 @@ static void ggml_compute_forward_clamp( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_BF16_R16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -22651,6 +22666,11 @@ size_t ggml_quantize_chunk( ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n); result = n * elemsize; } break; + case GGML_TYPE_BF16_R16: + { + repack_f32_bf16_r16(src + start, (char *) dst + start_row * row_size, nrows, n_per_row); + result = nrows * row_size; + } break; case GGML_TYPE_F32: { size_t elemsize = sizeof(float); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 75e5c3c1..d1af9fe8 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -96,6 +96,11 @@ struct DataInfo { _mm256_storeu_ps(dst_row(iy) + ix, result); } #endif +#ifdef __AVX512F__ + inline void store(int ix, int iy, __m512 result) const { + _mm512_storeu_ps(dst_row(iy) + ix, result); + } +#endif #ifdef __ARM_NEON inline void store(int ix, int iy, float32x4_t result) const { vst1q_f32(dst_row(iy) + ix, result); @@ -179,6 +184,7 @@ struct MulMat { case GGML_TYPE_IQ4_XS_R4: case GGML_TYPE_IQ2_BN_R4: return 4; case GGML_TYPE_Q8_K_R8: return 8; + case GGML_TYPE_BF16_R16: return 16; default: return 1; } } @@ -3876,6 +3882,72 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn } } +#ifdef __AVX512BF16__ +template +static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%16 == 0); + const ggml_bf16_t * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy); + for (int ix = 0; ix < nrc_x/32; ++ix) { + __m512 acc[2*nrc_y] = {}; + __m512bh qx[8]; + const ggml_bf16_t * b8_1 = (const ggml_bf16_t *)((const char *)vx + (32*ix+ 0)*bx); + const ggml_bf16_t * b8_2 = (const ggml_bf16_t *)((const char *)vx + (32*ix+16)*bx); + for (int ib = 0; ib < n/8; ++ib) { + qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+0); + qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+1); + qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+2); + qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+3); + qx[4] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+0); + qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1); + qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2); + qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib); + //auto y = _mm512_broadcast_i32x4(y128); + auto y256 = MM256_SET_M128I(y128, y128); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[4], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(32*ix+ 0, iy, acc[2*iy+0]); + info.store(32*ix+16, iy, acc[2*iy+1]); + } + } + for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) { + __m512 acc[nrc_y] = {}; + __m512bh qx[4]; + const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char *)vx + (ix+0)*bx); + for (int ib = 0; ib < n/8; ++ib) { + qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0); + qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1); + qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2); + qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib); + auto y256 = MM256_SET_M128I(y128, y128); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + } + } +} +#endif + template static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -5512,7 +5584,8 @@ struct QFBaseBF16 { using Data = __m512bh; using Acc = __m512; static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); } - static inline Acc acc(Acc prev, const Data& y, const Data& x) { + //static inline Acc acc(Acc prev, const Data& y, const Data& x) { + static inline Acc acc(Acc prev, Data y, Data x) { return _mm512_dpbf16_ps(prev, y, x); } static inline Acc acc_first(const Data& y, const Data& x) { @@ -5563,6 +5636,7 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, } for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16::hsum(acc[nrc_x*iy+ix])); } + template void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { constexpr int k_nx = nrc_y <= 2 ? 8 : 5; @@ -5777,6 +5851,17 @@ void set_mul_mat_bf16(MulMat& mm) { mm.funcs[3] = mul_mat_fX_fY_T<4>; mm.funcs[4] = mul_mat_fX_fY_T<5>; } +void set_mul_mat_bf16_r16(MulMat& mm) { + for (auto& f : mm.funcs) f = nullptr; + mm.funcs[0] = mul_mat_bf16_r16_bf16<1>; + mm.funcs[1] = mul_mat_bf16_r16_bf16<2>; + mm.funcs[2] = mul_mat_bf16_r16_bf16<3>; + mm.funcs[3] = mul_mat_bf16_r16_bf16<4>; + mm.funcs[4] = mul_mat_bf16_r16_bf16<5>; + mm.funcs[5] = mul_mat_bf16_r16_bf16<6>; + mm.funcs[6] = mul_mat_bf16_r16_bf16<7>; + mm.funcs[7] = mul_mat_bf16_r16_bf16<8>; +} #endif bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { @@ -5794,6 +5879,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { return true; } + if (typeA == GGML_TYPE_BF16_R16) { + if (ne00 % 16) return false; + switch (typeB) { +#ifdef __AVX512BF16__ + case GGML_TYPE_BF16: set_mul_mat_bf16_r16(mm); break; +#endif + default: return false; + } + return true; + } + if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) { if (ne00 % 4) return false; } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index de8c0d99..abe81858 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -4708,7 +4708,7 @@ static void repack_q8_k(int nrows, int n_per_row, const block_q8_K * x, block_q8 } } } - x += 4*nblock; + x += 8*nblock; y += nblock; } } @@ -4759,3 +4759,39 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b GGML_UNUSED(by); } +// +// ========================================= bf16_r4 +// +namespace { +inline ggml_bf16_t to_bf16(const float& x) { + union { float f; uint32_t u; } helper; + helper.f = x; + return ggml_bf16_t{(uint16_t)(helper.u >> 16)}; +} +inline ggml_bf16_t to_bf16(const ggml_bf16_t& x) { return x; } +template +void repack_bf16(int nrows, int n_per_row, const T * x, ggml_bf16_t * y) { + GGML_ASSERT(nrows%16 == 0); + GGML_ASSERT(n_per_row%2 == 0); + for (int row = 0; row < nrows; row += 16) { + for (int k = 0; k < 16; ++k) { + auto x8 = x + k*n_per_row; + for (int ib = 0; ib < n_per_row/2; ++ib) { + y[32*ib + 2*k + 0] = to_bf16(x8[2*ib+0]); + y[32*ib + 2*k + 1] = to_bf16(x8[2*ib+1]); + } + } + x += 16*n_per_row; + y += 16*n_per_row; + } +} +} + +void repack_f32_bf16_r16(const void * src, void * dst, int64_t nrows, int64_t n_per_row) { + repack_bf16(nrows, n_per_row, (const float *)src, (ggml_bf16_t *)dst); +} + +void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row) { + repack_bf16(nrows, n_per_row, (const ggml_bf16_t *)src, (ggml_bf16_t *)dst); +} + diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 753bbdb5..e8721a5e 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -158,6 +158,9 @@ void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); +void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); + #ifdef __cplusplus } #endif -- cgit v1.2.3