author     Kawrakow <iwankawrakow@gmail.com>     2024-12-15 09:54:21 +0100
committer  GitHub <noreply@github.com>           2024-12-15 09:54:21 +0100
commit     85c5a1a99569ccc00c280835fe3a69b4af02c43b (patch)
tree       da421487d5ddd0467b2bfd6cbbfb2666406c46f1
parent     20758edcae65213b2f575b6d23dfea67ad9dd0e0 (diff)
BF16_R16 - 16 interleaved bf16 rows (#142)
* Not working bf16_r4

* Adding bf16_r8

  Small performance gain compared to bf16 - 258 t/s vs 234 t/s.
  I guess this is still sub-optimal.

* bf16_rx: Very slightly faster by interleaving 16 rows

  258 t/s -> 263 t/s

* Rename bf16_r4 to bf16_r16

  We are interleaving 16 rows now.

* Cleanup unused stuff

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
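The R16 layout stores groups of 16 consecutive rows with their elements interleaved in pairs of bf16 values, so a single 512-bit load fetches the same element pair from all 16 rows. A minimal scalar sketch of the mapping (illustrative only; the actual implementation is repack_bf16 in ggml/src/iqk/iqk_quantize.cpp below):

    // Illustrative sketch of the BF16_R16 interleaving (mirrors repack_bf16 below).
    // Rows are processed in groups of 16; pair ib of row k within a group lands at
    // dst[32*ib + 2*k]. bf16 values are treated as opaque uint16_t here.
    static void interleave_bf16_r16(int nrows, int n_per_row,
                                    const uint16_t * src, uint16_t * dst) {
        // nrows % 16 == 0 and n_per_row % 2 == 0, as asserted in the commit
        for (int row = 0; row < nrows; row += 16) {
            for (int k = 0; k < 16; ++k) {
                const uint16_t * x = src + k*n_per_row;
                for (int ib = 0; ib < n_per_row/2; ++ib) {
                    dst[32*ib + 2*k + 0] = x[2*ib + 0];
                    dst[32*ib + 2*k + 1] = x[2*ib + 1];
                }
            }
            src += 16*n_per_row;
            dst += 16*n_per_row;
        }
    }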
-rw-r--r--  examples/quantize/quantize.cpp    1
-rw-r--r--  ggml/include/ggml.h               2
-rw-r--r--  ggml/src/ggml-quants.c            1
-rw-r--r--  ggml/src/ggml.c                  20
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp     98
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp    38
-rw-r--r--  ggml/src/iqk/iqk_quantize.h       3
-rw-r--r--  include/llama.h                   1
-rw-r--r--  src/llama.cpp                    13
9 files changed, 175 insertions, 2 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 0f906b83..5e5dd7c0 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -77,6 +77,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, -0.0020 ppl @ Mistral-7B", },
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", },
+ { "BF16_R16", LLAMA_FTYPE_MOSTLY_BF16_R16, "14.00G, -0.0050 ppl @ Mistral-7B", },
{ "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },
// Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index f9ff97a7..7b7fde0d 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -420,6 +420,7 @@ extern "C" {
GGML_TYPE_Q6_K_R4 = 214,
GGML_TYPE_IQ4_NL_R4 = 220,
GGML_TYPE_IQ4_XS_R4 = 223,
+ GGML_TYPE_BF16_R16 = 230,
GGML_TYPE_Q6_0_R4 = 233,
GGML_TYPE_IQ2_BN_R4 = 335,
GGML_TYPE_IQ4_K_R4 = 339,
@@ -493,6 +494,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q6_K_R4 = 214, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_NL_R4 = 219, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_XS_R4 = 222, // except 1d tensors
+ GGML_FTYPE_MOSTLY_BF16_R16 = 224, // except 1d tensors
GGML_FTYPE_MOSTLY_Q6_0_R4 = 227, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_BN_R4 = 329, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index f12c9fe8..0b157295 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15209,6 +15209,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_Q6_K_R4: break;
case GGML_TYPE_IQ4_K_R4: break;
case GGML_TYPE_Q8_K_R8: break;
+ case GGML_TYPE_BF16_R16: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 772c70c4..51ef6eb2 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1231,6 +1231,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_BF16_R16] = {
+ .type_name = "bf16_r16",
+ .blck_size = 1,
+ .type_size = sizeof(ggml_bf16_t),
+ .is_quantized = false,
+ //.to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
+ //.from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+ //.from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref,
+ //.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
+ .vec_dot_type = GGML_TYPE_BF16,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_Q4_0_4_4] = {
.type_name = "q4_0_4x4",
.blck_size = QK4_0,
@@ -4110,6 +4123,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
+ case GGML_FTYPE_MOSTLY_BF16_R16: wtype = GGML_TYPE_BF16_R16;break;
case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
@@ -15748,6 +15762,7 @@ static void ggml_compute_forward_clamp(
} break;
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
+ case GGML_TYPE_BF16_R16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -22651,6 +22666,11 @@ size_t ggml_quantize_chunk(
ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n);
result = n * elemsize;
} break;
+ case GGML_TYPE_BF16_R16:
+ {
+ repack_f32_bf16_r16(src + start, (char *) dst + start_row * row_size, nrows, n_per_row);
+ result = nrows * row_size;
+ } break;
case GGML_TYPE_F32:
{
size_t elemsize = sizeof(float);
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 75e5c3c1..d1af9fe8 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -96,6 +96,11 @@ struct DataInfo {
_mm256_storeu_ps(dst_row(iy) + ix, result);
}
#endif
+#ifdef __AVX512F__
+ inline void store(int ix, int iy, __m512 result) const {
+ _mm512_storeu_ps(dst_row(iy) + ix, result);
+ }
+#endif
#ifdef __ARM_NEON
inline void store(int ix, int iy, float32x4_t result) const {
vst1q_f32(dst_row(iy) + ix, result);
@@ -179,6 +184,7 @@ struct MulMat {
case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_IQ2_BN_R4: return 4;
case GGML_TYPE_Q8_K_R8: return 8;
+ case GGML_TYPE_BF16_R16: return 16;
default: return 1;
}
}
@@ -3876,6 +3882,72 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
}
}
+#ifdef __AVX512BF16__
+template <int nrc_y>
+static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%16 == 0);
+ const ggml_bf16_t * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
+ for (int ix = 0; ix < nrc_x/32; ++ix) {
+ __m512 acc[2*nrc_y] = {};
+ __m512bh qx[8];
+ const ggml_bf16_t * b8_1 = (const ggml_bf16_t *)((const char *)vx + (32*ix+ 0)*bx);
+ const ggml_bf16_t * b8_2 = (const ggml_bf16_t *)((const char *)vx + (32*ix+16)*bx);
+ for (int ib = 0; ib < n/8; ++ib) {
+ qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+0);
+ qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+1);
+ qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+2);
+ qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+3);
+ qx[4] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+0);
+ qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1);
+ qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2);
+ qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
+ //auto y = _mm512_broadcast_i32x4(y128);
+ auto y256 = MM256_SET_M128I(y128, y128);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[4], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(32*ix+ 0, iy, acc[2*iy+0]);
+ info.store(32*ix+16, iy, acc[2*iy+1]);
+ }
+ }
+ for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) {
+ __m512 acc[nrc_y] = {};
+ __m512bh qx[4];
+ const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char *)vx + (ix+0)*bx);
+ for (int ib = 0; ib < n/8; ++ib) {
+ qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0);
+ qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
+ qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
+ qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
+ auto y256 = MM256_SET_M128I(y128, y128);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ }
+ }
+}
+#endif
+
template <int nrc_y>
static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -5512,7 +5584,8 @@ struct QFBaseBF16 {
using Data = __m512bh;
using Acc = __m512;
static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); }
- static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ //static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ static inline Acc acc(Acc prev, Data y, Data x) {
return _mm512_dpbf16_ps(prev, y, x);
}
static inline Acc acc_first(const Data& y, const Data& x) {
@@ -5563,6 +5636,7 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
}
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16::hsum(acc[nrc_x*iy+ix]));
}
+
template <int nrc_y>
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
constexpr int k_nx = nrc_y <= 2 ? 8 : 5;
@@ -5777,6 +5851,17 @@ void set_mul_mat_bf16(MulMat& mm) {
mm.funcs[3] = mul_mat_fX_fY_T<4>;
mm.funcs[4] = mul_mat_fX_fY_T<5>;
}
+void set_mul_mat_bf16_r16(MulMat& mm) {
+ for (auto& f : mm.funcs) f = nullptr;
+ mm.funcs[0] = mul_mat_bf16_r16_bf16<1>;
+ mm.funcs[1] = mul_mat_bf16_r16_bf16<2>;
+ mm.funcs[2] = mul_mat_bf16_r16_bf16<3>;
+ mm.funcs[3] = mul_mat_bf16_r16_bf16<4>;
+ mm.funcs[4] = mul_mat_bf16_r16_bf16<5>;
+ mm.funcs[5] = mul_mat_bf16_r16_bf16<6>;
+ mm.funcs[6] = mul_mat_bf16_r16_bf16<7>;
+ mm.funcs[7] = mul_mat_bf16_r16_bf16<8>;
+}
#endif
bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
@@ -5794,6 +5879,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
return true;
}
+ if (typeA == GGML_TYPE_BF16_R16) {
+ if (ne00 % 16) return false;
+ switch (typeB) {
+#ifdef __AVX512BF16__
+ case GGML_TYPE_BF16: set_mul_mat_bf16_r16(mm); break;
+#endif
+ default: return false;
+ }
+ return true;
+ }
+
if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) {
if (ne00 % 4) return false;
}
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index de8c0d99..abe81858 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -4708,7 +4708,7 @@ static void repack_q8_k(int nrows, int n_per_row, const block_q8_K * x, block_q8
}
}
}
- x += 4*nblock;
+ x += 8*nblock;
y += nblock;
}
}
@@ -4759,3 +4759,39 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
GGML_UNUSED(by);
}
+//
+// ========================================= bf16_r4
+//
+namespace {
+inline ggml_bf16_t to_bf16(const float& x) {
+ union { float f; uint32_t u; } helper;
+ helper.f = x;
+ return ggml_bf16_t{(uint16_t)(helper.u >> 16)};
+}
+inline ggml_bf16_t to_bf16(const ggml_bf16_t& x) { return x; }
+template <typename T>
+void repack_bf16(int nrows, int n_per_row, const T * x, ggml_bf16_t * y) {
+ GGML_ASSERT(nrows%16 == 0);
+ GGML_ASSERT(n_per_row%2 == 0);
+ for (int row = 0; row < nrows; row += 16) {
+ for (int k = 0; k < 16; ++k) {
+ auto x8 = x + k*n_per_row;
+ for (int ib = 0; ib < n_per_row/2; ++ib) {
+ y[32*ib + 2*k + 0] = to_bf16(x8[2*ib+0]);
+ y[32*ib + 2*k + 1] = to_bf16(x8[2*ib+1]);
+ }
+ }
+ x += 16*n_per_row;
+ y += 16*n_per_row;
+ }
+}
+}
+
+void repack_f32_bf16_r16(const void * src, void * dst, int64_t nrows, int64_t n_per_row) {
+ repack_bf16(nrows, n_per_row, (const float *)src, (ggml_bf16_t *)dst);
+}
+
+void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row) {
+ repack_bf16(nrows, n_per_row, (const ggml_bf16_t *)src, (ggml_bf16_t *)dst);
+}
+
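Since the to_float hook for this type is left unset in ggml.c, an inverse of repack_bf16 can be handy when checking the layout. A sketch (not part of the commit) that recovers row-major fp32 from one 16-row group:

    // Illustrative inverse of repack_bf16: undo the interleaving of one 16-row group.
    static inline float bf16_to_f32(uint16_t h) {
        union { uint32_t u; float f; } v = { (uint32_t)h << 16 };
        return v.f;
    }
    static void unpack_bf16_r16_group(int n_per_row, const uint16_t * packed, float * rows) {
        for (int k = 0; k < 16; ++k) {
            for (int ib = 0; ib < n_per_row/2; ++ib) {
                rows[k*n_per_row + 2*ib + 0] = bf16_to_f32(packed[32*ib + 2*k + 0]);
                rows[k*n_per_row + 2*ib + 1] = bf16_to_f32(packed[32*ib + 2*k + 1]);
            }
        }
    }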
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 753bbdb5..e8721a5e 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -158,6 +158,9 @@ void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
+void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
+
#ifdef __cplusplus
}
#endif
diff --git a/include/llama.h b/include/llama.h
index e4d6ed3d..988ffec7 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -191,6 +191,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 = 225, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 = 230, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_0_R4 = 335, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_BF16_R16 = 232, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ2_BN_R4 = 337, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
diff --git a/src/llama.cpp b/src/llama.cpp
index 035e5b1a..536b2f97 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3828,6 +3828,7 @@ struct llama_model_loader {
case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break;
case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break;
case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break;
+ case GGML_TYPE_BF16_R16:ftype = LLAMA_FTYPE_MOSTLY_BF16_R16;break;
case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break;
case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break;
case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break;
@@ -4540,6 +4541,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_ALL_F32: return "all F32";
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
+ case LLAMA_FTYPE_MOSTLY_BF16_R16: return "BF16_R16";
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
@@ -15833,6 +15835,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_Q8_0_R4) {
new_type = GGML_TYPE_Q8_0;
}
+ else if (new_type == GGML_TYPE_BF16_R16) {
+ new_type = GGML_TYPE_BF16;
+ }
}
} else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M ||
@@ -16228,6 +16233,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
+ case LLAMA_FTYPE_MOSTLY_BF16_R16: default_type = GGML_TYPE_BF16_R16; break;
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
// K-quants
@@ -16520,6 +16526,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (quantize) {
new_type = default_type;
+ if (new_type == GGML_TYPE_BF16_R16 && strcmp(tensor->name, "token_embd.weight") == 0) {
+ new_type = GGML_TYPE_BF16;
+ }
// get more optimal quantization type based on the tensor shape, layer, etc.
if (!params->pure && ggml_is_quantized(default_type)) {
@@ -16680,6 +16689,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_K;
else chunk_size_multiplier = 4;
}
+ else if (new_type == GGML_TYPE_BF16_R16) {
+ if (tensor->ne[1] % 16 != 0) new_type = GGML_TYPE_BF16;
+ else chunk_size_multiplier = 16;
+ }
LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
fflush(stdout);