author     Kawrakow <iwankawrakow@gmail.com>    2024-12-08 15:27:13 +0100
committer  GitHub <noreply@github.com>          2024-12-08 15:27:13 +0100
commit     43e65a672a98d931998559785b58f1e980e87f54 (patch)
tree       f0cefd9605710ff529d637e507af01327792d004 /ggml/src
parent     fc701cedd146152fb270482a7eef5aba23b20575 (diff)
Faster IQ4_XS_R4 on Zen4 (#128)
* Faster iq4_xs_r4 on Zen4
The trick is simply to prepare the Q8 block sums for blocks of 32
as floats, pre-scaled by the row scale (see the sketch after the
commit message). This brings PP-512 up to 254.6 t/s from 224 t/s.
* Fix broken matrix x vector product on Zen4
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
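
Before the diff, here is a minimal standalone sketch of the trick, assuming only the block_q8_K layout from ggml; the helper names store_block_sums_as_floats and block_sum_correction are illustrative, not code from this commit.

#include <stdint.h>

#define QK_K 256

typedef struct {
    float   d;               // dequantization scale for the row
    int8_t  qs[QK_K];        // quantized values
    int16_t bsums[QK_K/16];  // normally 16 int16 sums (one per 16 quants);
                             // Q8_K32 reuses these 32 bytes as 8 floats,
                             // one per block of 32
} block_q8_K;

// Producer side (quantization): store d * sum(32 quants) per block of 32,
// so the int->float conversion and the multiply by d happen once per row
// instead of inside the matrix-multiplication inner loop.
static void store_block_sums_as_floats(block_q8_K * y) {
    float * bs = (float *)y->bsums;      // reinterpret the bsums storage
    for (int j = 0; j < QK_K/32; ++j) {
        int sum = 0;
        for (int ii = 0; ii < 32; ++ii) sum += y->qs[j*32 + ii];
        bs[j] = y->d * sum;
    }
}

// Consumer side (mul-mat kernel): the correction term for block ib becomes
// a single float load, replacing d8 * (bsums[2*ib+0] + bsums[2*ib+1]).
static inline float block_sum_correction(const block_q8_K * y, int ib) {
    return ((const float *)y->bsums)[ib];
}

Since the kernel revisits each Q8 block many times (once per quantized row group of the left matrix), hoisting the int16 adds and the d8 multiply into quantization is presumably where the PP-512 gain comes from.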
Diffstat (limited to 'ggml/src')
-rw-r--r--  ggml/src/ggml.c                11
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp   22
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp  44
-rw-r--r--  ggml/src/iqk/iqk_quantize.h     4
4 files changed, 54 insertions, 27 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 69bb6d88..974e42b2 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1124,6 +1124,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float    = quantize_row_q8_K16,
         .row_meta_size = 20,
     },
+    [GGML_TYPE_Q8_K32] = {
+        .type_name     = "q8_K32",
+        .blck_size     = QK_K,
+        .type_size     = sizeof(block_q8_K),
+        .is_quantized  = true,
+        .from_float    = quantize_row_q8_K32,
+        .row_meta_size = 0,
+    },
     [GGML_TYPE_BF16] = {
         .type_name     = "bf16",
         .blck_size     = 1,
@@ -1292,7 +1300,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float     = quantize_row_iq4_xs_r4,
         .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_r4_ref,
         .vec_dot        = vec_dot_iq4_xs_r4_q8_k,
-        .vec_dot_type   = GGML_TYPE_Q8_K,
+        .vec_dot_type   = GGML_TYPE_Q8_K32,
         .nrows          = 1,
         .row_meta_size  = 0,
     },
@@ -15633,6 +15641,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_Q8_K64:
         case GGML_TYPE_Q8_K16:
+        case GGML_TYPE_Q8_K32:
         case GGML_TYPE_Q4_0_4_4:
         case GGML_TYPE_Q4_0_4_8:
         case GGML_TYPE_Q4_0_8_8:
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 69c0b9a4..a9adedfc 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2917,15 +2917,16 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
 
 template <int nrc_y>
 static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%8 == 0);
+    GGML_ASSERT(nrc_x%4 == 0);
     Q8<nrc_y, block_q8_K> q8(info);
     auto m4 = _mm256_set1_epi8(0xf);
 #ifndef HAVE_FANCY_SIMD
     auto m1 = _mm256_set1_epi16(1);
-#endif
     auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
     auto values = MM256_SET_M128I(values128, values128);
-    //auto values = load_iq4nl_values_256();
+#else
+    auto values = load_iq4nl_values_256();
+#endif
     int nbl = n / QK_K;
     using helper_t = union { __m256i vec; uint32_t val[8]; };
     helper_t h;
@@ -2969,7 +2970,7 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
                 sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                 sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                 float d8 = q8.scale(iy, ibl);
-                float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
+                float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
                 acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
                 acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
 #else
@@ -2979,15 +2980,6 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
                                              _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
                 auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
                 acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
-                //auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00))),
-                //                              _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))));
-                //auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa))),
-                //                              _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))));
-                //auto sumi = _mm256_add_epi32(sumi1, sumi2);
-                //float d8 = q8.scale(iy, ibl);
-                //float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
-                //acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
-                //acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
 #endif
             }
         }
@@ -3057,7 +3049,7 @@ static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
                 sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
                 sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
                 float d8 = q8.scale(iy, ibl);
-                float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
+                float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
                 acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, _mm512_set1_ps(d8)), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
                 acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[2*iy+1]);
             }
@@ -5074,7 +5066,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
             mm.funcs[5] = mul_mat_iq4_xs_r4_q8_k<6>;
             mm.funcs[6] = mul_mat_iq4_xs_r4_q8_k<7>;
             mm.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>;
-            expected_typeB = GGML_TYPE_Q8_K;
+            expected_typeB = GGML_TYPE_Q8_K32;
             break;
         case GGML_TYPE_Q4_0_R4:
             assert (ne00 % QK4_NL == 0);
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 0ea608bc..66cf32bc 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -2469,8 +2469,8 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
     return nrows * nblock * sizeof(block_iq6_k);
 }
 
-
-void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
+template <bool is_K32>
+void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
     block_q8_K * y = (block_q8_K *)vy;
@@ -2505,8 +2505,14 @@
         __m256i i1 = _mm256_cvtps_epi32(v1);
         __m256i i2 = _mm256_cvtps_epi32(v2);
         __m256i i3 = _mm256_cvtps_epi32(v3);
-        y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
-        y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
+        if constexpr (is_K32) {
+            int bsum = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+            auto bs = (float *)y[i].bsums;
+            bs[ib] = d*bsum;
+        } else {
+            y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
+            y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
+        }
         i0 = _mm256_packs_epi32( i0, i1 );
         i2 = _mm256_packs_epi32( i2, i3 );
         i0 = _mm256_packs_epi16( i0, i2 );
@@ -2539,12 +2545,24 @@
             int v = nearest_int(iscale*x[j]);
             y[i].qs[j] = MIN(127, v);
         }
-        for (int j = 0; j < QK_K/16; ++j) {
-            int sum = 0;
-            for (int ii = 0; ii < 16; ++ii) {
-                sum += y[i].qs[j*16 + ii];
+        if constexpr (is_K32) {
+            auto bs = (float *)y[i].bsums;
+            float d = 1/iscale;
+            for (int j = 0; j < QK_K/32; ++j) {
+                int sum = 0;
+                for (int ii = 0; ii < 32; ++ii) {
+                    sum += y[i].qs[j*32 + ii];
+                }
+                bs[j] = d*sum;
+            }
+        } else {
+            for (int j = 0; j < QK_K/16; ++j) {
+                int sum = 0;
+                for (int ii = 0; ii < 16; ++ii) {
+                    sum += y[i].qs[j*16 + ii];
+                }
+                y[i].bsums[j] = sum;
             }
-            y[i].bsums[j] = sum;
         }
         y[i].d = 1/iscale;
         x += QK_K;
@@ -2553,6 +2571,14 @@
     }
 }
 
+void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
+    iqk_quantize_row_q8_K_T<false>(x, vy, k);
+}
+
+void quantize_row_q8_K32(const float * x, void * vy, int64_t k) {
+    iqk_quantize_row_q8_K_T<true>(x, vy, k);
+}
+
 namespace {
 
 static void quantize_row_iq4_k_impl_bs128(const int super_block_size, const int block_size, int n_per_row, const float * x, char * cy,
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index f5bbd3dc..c370c8e8 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -61,8 +61,6 @@ size_t quantize_iq2_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst
 void dequantize_row_iq2_ks(const block_iq2_ks * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 void vec_dot_iq2_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
-void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
-
 void quantize_row_iq4_nl_r4_ref(const float * GGML_RESTRICT x, block_iq4_nl_r4 * GGML_RESTRICT y, int64_t k);
 void quantize_row_iq4_nl_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 size_t quantize_iq4_nl_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
@@ -111,9 +109,11 @@ void dequantize_row_iq2_bn_r4(const block_iq2_bn * GGML_RESTRICT x, float * GG
 size_t quantize_iq2_bn_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 void vec_dot_iq2_bn_r4_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
+void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
 #ifdef __cplusplus
 }
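
One detail worth checking when reading the quantization change above: the float block sums only fit because the reused bsums field is exactly the right size, which is also why the new type_traits entry keeps .type_size = sizeof(block_q8_K). A quick standalone check (hypothetical, not part of the commit):

// Sketch: the QK_K/16 int16 bsums of block_q8_K occupy 32 bytes, exactly
// the space needed for QK_K/32 floats (one per block of 32), so Q8_K32 can
// reuse the field without changing the block layout.
static_assert(sizeof(int16_t[256/16]) == sizeof(float[256/32]),
              "float block sums must fit in the bsums storage");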