summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-09-14 13:53:50 +0300
committerGitHub <noreply@github.com>2024-09-14 13:53:50 +0300
commit064b99365c4426b83b09a518b29cc1ffc0250d04 (patch)
treeb725f675e3d7e7496f895e77f5af5b6bf82da8ab
parent43b934b19fec38219299b6e03bc9479143b593fd (diff)
Improve Q4_0 and Q8_0 performance on AVX2/Zen4 (#54)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml.c8
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp62
2 files changed, 65 insertions, 5 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 4fdf9c18..060d10c6 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -707,7 +707,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
.vec_dot = ggml_vec_dot_q4_0_q8_0,
+#if GGML_USE_IQK_MULMAT && defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1,
+#else
.vec_dot_type = GGML_TYPE_Q8_0,
+#endif
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
@@ -788,7 +792,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
.from_float_to_mat = quantize_mat_q8_0,
.vec_dot = ggml_vec_dot_q8_0_q8_0,
+#if GGML_USE_IQK_MULMAT && defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1,
+#else
.vec_dot_type = GGML_TYPE_Q8_0,
+#endif
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 21fe99e1..8888534c 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2923,6 +2923,29 @@ struct ScaleHelperQ_0 {
template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
};
+template <int min_value>
+struct ScaleHelperQ_0_1 {
+ ggml_half scales8[4];
+ template <typename Q>
+ inline __m256 prepare4(const Q * y) {
+ for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;
+ auto s4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));
+ return _mm256_set_m128(_mm_mul_ps(s4, min), s4);
+ }
+ template <typename Q>
+ inline __m256 prepare4(__m256 other_scales, const Q * y) {
+ return _mm_mul256_ps(other_scales, prepare4<Q>(y));
+ }
+ template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
+ float d = GGML_FP16_TO_FP32(y->d);
+ return std::make_pair(d, -d*float(min_value));
+ }
+ std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
+ return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
+ }
+ const __m128 min = _mm_set1_ps(float(-min_value));
+};
+
struct ScaleHelperQ8_1 {
template <typename Q>
inline __m256 prepare4(const Q * y) {
@@ -3044,6 +3067,7 @@ using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;
using Sum4Type0 = Sum4<block_q8_0, block_q8_0_x4, SignedDot>;
using Sum4Type1 = Sum4<block_q8_1, block_q8_1_x4, UnsignedDot>;
using Sum4TypeQ80 = Sum4<block_q8_0, block_q8_0_x4, SignedDot, false>;
+using Sum4TypeQ81 = Sum4<block_q8_1, block_q8_1_x4, UnsignedDot, false>;
template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {
@@ -3103,6 +3127,12 @@ struct Q8_0_Dequantizer {
}
};
+struct Q8_0_1_Dequantizer {
+ inline __m256i dequant(const block_q8_0 * x) const {
+ return _mm256_add_epi8(_mm256_set1_epi8(127), _mm256_loadu_si256((const __m256i *)x->qs));
+ }
+};
+
struct Q4_0_Dequantizer {
Dequantizer4bit b4;
const __m256i m8 = _mm256_set1_epi8(-8);
@@ -3111,6 +3141,13 @@ struct Q4_0_Dequantizer {
}
};
+struct Q4_0_1_Dequantizer {
+ Dequantizer4bit b4;
+ inline __m256i dequant(const block_q4_0 * x) const {
+ return b4.dequant(x->qs);
+ }
+};
+
struct IQ4_NL_Dequantizer {
Dequantizer4bit b4;
const __m256i values = load_values();
@@ -3231,11 +3268,21 @@ struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_
using Sum4T = Sum4TypeQ80;
inline static int block_size() { return QK8_0; }
};
+struct Q8_0_1_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0_1<127>, Q8_0_1_Dequantizer> {
+ Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ81;
+ inline static int block_size() { return QK8_0; }
+};
struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {
Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ80;
inline static int block_size() { return QK4_0; }
};
+struct Q4_0_1_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0_1<8>, Q4_0_1_Dequantizer> {
+ Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ81;
+ inline static int block_size() { return QK4_0; }
+};
struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL_Dequantizer> {
IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ80;
@@ -3550,7 +3597,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
}
- else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>) {
+ else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker> ||
+ std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || std::is_same_v<Dequantizer, Q4_0_1_Unpacker>) {
m.funcs[0] = mul_mat_qX_1_q8_1_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_1_q8_1_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_1_q8_1_T<Dequantizer, 3>;
@@ -3815,8 +3863,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
- MulMat::set_functions<Q4_0_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_0;
+ //MulMat::set_functions<Q4_0_Unpacker>(mm);
+ //expected_typeB = GGML_TYPE_Q8_0;
+ MulMat::set_functions<Q4_0_1_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_1;
break;
case GGML_TYPE_Q4_1:
assert (ne00 % QK4_1 == 0);
@@ -3835,8 +3885,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_Q8_0:
assert (ne00 % QK8_0 == 0);
- MulMat::set_functions<Q8_0_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_0;
+ //MulMat::set_functions<Q8_0_Unpacker>(mm);
+ //expected_typeB = GGML_TYPE_Q8_0;
+ MulMat::set_functions<Q8_0_1_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_1;
break;
case GGML_TYPE_IQ4_NL:
assert (ne00 % QK4_NL == 0);