summaryrefslogtreecommitdiff
path: root/ggml/src
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src')
-rw-r--r--ggml/src/ggml-quants.c1
-rw-r--r--ggml/src/ggml.c20
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp98
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp38
-rw-r--r--ggml/src/iqk/iqk_quantize.h3
5 files changed, 158 insertions, 2 deletions
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index f12c9fe8..0b157295 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15209,6 +15209,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_Q6_K_R4: break;
case GGML_TYPE_IQ4_K_R4: break;
case GGML_TYPE_Q8_K_R8: break;
+ case GGML_TYPE_BF16_R16: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 772c70c4..51ef6eb2 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1231,6 +1231,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_BF16_R16] = {
+ .type_name = "bf16_r16",
+ .blck_size = 1,
+ .type_size = sizeof(ggml_bf16_t),
+ .is_quantized = false,
+ //.to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
+ //.from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+ //.from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref,
+ //.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
+ .vec_dot_type = GGML_TYPE_BF16,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_Q4_0_4_4] = {
.type_name = "q4_0_4x4",
.blck_size = QK4_0,
@@ -4110,6 +4123,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
+ case GGML_FTYPE_MOSTLY_BF16_R16: wtype = GGML_TYPE_BF16_R16;break;
case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
@@ -15748,6 +15762,7 @@ static void ggml_compute_forward_clamp(
} break;
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
+ case GGML_TYPE_BF16_R16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -22651,6 +22666,11 @@ size_t ggml_quantize_chunk(
ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n);
result = n * elemsize;
} break;
+ case GGML_TYPE_BF16_R16:
+ {
+ repack_f32_bf16_r16(src + start, (char *) dst + start_row * row_size, nrows, n_per_row);
+ result = nrows * row_size;
+ } break;
case GGML_TYPE_F32:
{
size_t elemsize = sizeof(float);
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 75e5c3c1..d1af9fe8 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -96,6 +96,11 @@ struct DataInfo {
_mm256_storeu_ps(dst_row(iy) + ix, result);
}
#endif
+#ifdef __AVX512F__
+ inline void store(int ix, int iy, __m512 result) const {
+ _mm512_storeu_ps(dst_row(iy) + ix, result);
+ }
+#endif
#ifdef __ARM_NEON
inline void store(int ix, int iy, float32x4_t result) const {
vst1q_f32(dst_row(iy) + ix, result);
@@ -179,6 +184,7 @@ struct MulMat {
case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_IQ2_BN_R4: return 4;
case GGML_TYPE_Q8_K_R8: return 8;
+ case GGML_TYPE_BF16_R16: return 16;
default: return 1;
}
}
@@ -3876,6 +3882,72 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
}
}
+#ifdef __AVX512BF16__
+template <int nrc_y>
+static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%16 == 0);
+ const ggml_bf16_t * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
+ for (int ix = 0; ix < nrc_x/32; ++ix) {
+ __m512 acc[2*nrc_y] = {};
+ __m512bh qx[8];
+ const ggml_bf16_t * b8_1 = (const ggml_bf16_t *)((const char *)vx + (32*ix+ 0)*bx);
+ const ggml_bf16_t * b8_2 = (const ggml_bf16_t *)((const char *)vx + (32*ix+16)*bx);
+ for (int ib = 0; ib < n/8; ++ib) {
+ qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+0);
+ qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+1);
+ qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+2);
+ qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+3);
+ qx[4] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+0);
+ qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1);
+ qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2);
+ qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
+ //auto y = _mm512_broadcast_i32x4(y128);
+ auto y256 = MM256_SET_M128I(y128, y128);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[4], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(32*ix+ 0, iy, acc[2*iy+0]);
+ info.store(32*ix+16, iy, acc[2*iy+1]);
+ }
+ }
+ for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) {
+ __m512 acc[nrc_y] = {};
+ __m512bh qx[4];
+ const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char *)vx + (ix+0)*bx);
+ for (int ib = 0; ib < n/8; ++ib) {
+ qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0);
+ qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
+ qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
+ qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
+ auto y256 = MM256_SET_M128I(y128, y128);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ }
+ }
+}
+#endif
+
template <int nrc_y>
static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -5512,7 +5584,8 @@ struct QFBaseBF16 {
using Data = __m512bh;
using Acc = __m512;
static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); }
- static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ //static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ static inline Acc acc(Acc prev, Data y, Data x) {
return _mm512_dpbf16_ps(prev, y, x);
}
static inline Acc acc_first(const Data& y, const Data& x) {
@@ -5563,6 +5636,7 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
}
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16::hsum(acc[nrc_x*iy+ix]));
}
+
template <int nrc_y>
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
constexpr int k_nx = nrc_y <= 2 ? 8 : 5;
@@ -5777,6 +5851,17 @@ void set_mul_mat_bf16(MulMat& mm) {
mm.funcs[3] = mul_mat_fX_fY_T<4>;
mm.funcs[4] = mul_mat_fX_fY_T<5>;
}
+void set_mul_mat_bf16_r16(MulMat& mm) {
+ for (auto& f : mm.funcs) f = nullptr;
+ mm.funcs[0] = mul_mat_bf16_r16_bf16<1>;
+ mm.funcs[1] = mul_mat_bf16_r16_bf16<2>;
+ mm.funcs[2] = mul_mat_bf16_r16_bf16<3>;
+ mm.funcs[3] = mul_mat_bf16_r16_bf16<4>;
+ mm.funcs[4] = mul_mat_bf16_r16_bf16<5>;
+ mm.funcs[5] = mul_mat_bf16_r16_bf16<6>;
+ mm.funcs[6] = mul_mat_bf16_r16_bf16<7>;
+ mm.funcs[7] = mul_mat_bf16_r16_bf16<8>;
+}
#endif
bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
@@ -5794,6 +5879,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
return true;
}
+ if (typeA == GGML_TYPE_BF16_R16) {
+ if (ne00 % 16) return false;
+ switch (typeB) {
+#ifdef __AVX512BF16__
+ case GGML_TYPE_BF16: set_mul_mat_bf16_r16(mm); break;
+#endif
+ default: return false;
+ }
+ return true;
+ }
+
if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) {
if (ne00 % 4) return false;
}
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index de8c0d99..abe81858 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -4708,7 +4708,7 @@ static void repack_q8_k(int nrows, int n_per_row, const block_q8_K * x, block_q8
}
}
}
- x += 4*nblock;
+ x += 8*nblock;
y += nblock;
}
}
@@ -4759,3 +4759,39 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
GGML_UNUSED(by);
}
+//
+// ========================================= bf16_r4
+//
+namespace {
+inline ggml_bf16_t to_bf16(const float& x) {
+ union { float f; uint32_t u; } helper;
+ helper.f = x;
+ return ggml_bf16_t{(uint16_t)(helper.u >> 16)};
+}
+inline ggml_bf16_t to_bf16(const ggml_bf16_t& x) { return x; }
+template <typename T>
+void repack_bf16(int nrows, int n_per_row, const T * x, ggml_bf16_t * y) {
+ GGML_ASSERT(nrows%16 == 0);
+ GGML_ASSERT(n_per_row%2 == 0);
+ for (int row = 0; row < nrows; row += 16) {
+ for (int k = 0; k < 16; ++k) {
+ auto x8 = x + k*n_per_row;
+ for (int ib = 0; ib < n_per_row/2; ++ib) {
+ y[32*ib + 2*k + 0] = to_bf16(x8[2*ib+0]);
+ y[32*ib + 2*k + 1] = to_bf16(x8[2*ib+1]);
+ }
+ }
+ x += 16*n_per_row;
+ y += 16*n_per_row;
+ }
+}
+}
+
+void repack_f32_bf16_r16(const void * src, void * dst, int64_t nrows, int64_t n_per_row) {
+ repack_bf16(nrows, n_per_row, (const float *)src, (ggml_bf16_t *)dst);
+}
+
+void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row) {
+ repack_bf16(nrows, n_per_row, (const ggml_bf16_t *)src, (ggml_bf16_t *)dst);
+}
+
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 753bbdb5..e8721a5e 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -158,6 +158,9 @@ void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
+void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
+
#ifdef __cplusplus
}
#endif