diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-12-14 09:24:30 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-14 09:24:30 +0100 |
commit | 20758edcae65213b2f575b6d23dfea67ad9dd0e0 (patch) | |
tree | f9f32d541da8bb945a45bbf473b9295496ec5c2b /ggml/src | |
parent | 12f962dd2494b743deb1c671974a591fdef1f003 (diff) |
Q8_K_R8: Fastest quantized matrix multiplications (#141)
* q8_k_r8: fastest matrix multiplication known to human kind
We get PP-512(LLaMA-3.1-8B) = 370 t/s on a Ryzen-7950X!
* q8_k_r8: AVX2
I was worried that we don't have enough vector registrers on
AVX2, but it looks like it handles it just fine. We get
PP-512(LLaMA-3.1-8B) = 354 t/s on a Ryzen-5975WX.
Slightly slower than the Zen4 version with double the threads,
but still a huge upgrade compared to Q8_0_R4.
* q8_k_r4: NEON
We get PP-512(LLaMA-3.1-8B) = 159.2 t/s.
Compare this to the 128 t/s we have fr Q8_0_R4.
* q8_k_r4: go to signed ints
Why?
* On AVX2 _mm256_maddubs_epi16() may overflow, so we need to
stay within the signed int range and use _mm256_sign_epi8.
Not yet tested on the AVX2 comp, vut expect major slowdown.
* It is almost 10% faster on ARM_NEON. Somehow the veorrq_u8()
needed tto convert from unsigned to signed seems to be extremely
slow on the M2-Max
* We only lose ~0.5% in oerformance on Zen4 (there the exclusive
or that we now use to convert fro signed to unsigned seems to be
much faster than on M2-Max)
* Shutup useless compiler warnings
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src')
-rw-r--r-- | ggml/src/ggml-common.h | 6 | ||||
-rw-r--r-- | ggml/src/ggml-quants.c | 1 | ||||
-rw-r--r-- | ggml/src/ggml.c | 31 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 139 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 103 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 7 |
6 files changed, 282 insertions, 5 deletions
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 2cacc711..d77ba12c 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -382,6 +382,12 @@ typedef struct { } block_q8_K128; static_assert(sizeof(block_q8_K128) == sizeof(float) + 128, "wrong q8_K128 block size/padding"); +typedef struct { + ggml_half d[8]; // delta + int8_t qs[8*QK_K]; // quants, stored as unsigned ints +} block_q8_k_r8; +static_assert(sizeof(block_q8_k_r8) == 8*sizeof(ggml_half) + 8*QK_K, "wrong q8_k_r8 block size/padding"); + // (Almost) "true" 2-bit quantization. // Due to the need to use blocks as per ggml design, it ends up using // 2.0625 bpw because of the 16-bit scale for each block of 256. diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 64bd9459..f12c9fe8 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15208,6 +15208,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_Q5_K_R4: break; case GGML_TYPE_Q6_K_R4: break; case GGML_TYPE_IQ4_K_R4: break; + case GGML_TYPE_Q8_K_R8: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 26ca7991..772c70c4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -979,6 +979,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q8_K_R8] = { + .type_name = "q8_k_r8", + .blck_size = QK_K, + .type_size = sizeof(block_q8_k_r8)/8, + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_k_r8, + .from_float = quantize_row_q8_k_r8, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_k_r8_ref, + .vec_dot = vec_dot_q8_k_r8_q8_k, + .vec_dot_type = GGML_TYPE_Q8_KR8, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ2_XXS] = { .type_name = "iq2_xxs", .blck_size = QK_K, @@ -1197,6 +1210,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q8_K32, .row_meta_size = 0, }, + [GGML_TYPE_Q8_KR8] = { + .type_name = "q8_KR8", + .blck_size = QK_K, + .type_size = sizeof(block_q8_K), + .is_quantized = true, + .from_float = quantize_row_q8_KR8, + .row_meta_size = 0, + }, [GGML_TYPE_BF16] = { .type_name = "bf16", .blck_size = 1, @@ -4105,6 +4126,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q5_K_R4: wtype = GGML_TYPE_Q5_K_R4; break; case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break; + case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break; case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; @@ -10641,6 +10663,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -11096,6 +11119,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -11248,6 +11272,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -14446,6 +14471,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -14838,6 +14864,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -15124,6 +15151,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -15737,6 +15765,8 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KR8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -22578,6 +22608,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q5_K_R4: result = quantize_q5_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 3f448275..75e5c3c1 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -92,6 +92,9 @@ struct DataInfo { inline void store(int ix, int iy, __m128 result) const { _mm_storeu_ps(dst_row(iy) + ix, result); } + inline void store(int ix, int iy, __m256 result) const { + _mm256_storeu_ps(dst_row(iy) + ix, result); + } #endif #ifdef __ARM_NEON inline void store(int ix, int iy, float32x4_t result) const { @@ -175,6 +178,7 @@ struct MulMat { case GGML_TYPE_IQ4_NL_R4: case GGML_TYPE_IQ4_XS_R4: case GGML_TYPE_IQ2_BN_R4: return 4; + case GGML_TYPE_Q8_K_R8: return 8; default: return 1; } } @@ -3802,6 +3806,76 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn } } +// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__) +template <int nrc_y> +static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ibl].d)); + for (int ib = 0; ib < QK_K/16; ++ib) { + qx[0] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0); + qx[1] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1); + qx[2] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2); + qx[3] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3); +#ifdef HAVE_FANCY_SIMD + qx[0] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0), _mm256_set1_epi8(-128)); + qx[1] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1), _mm256_set1_epi8(-128)); + qx[2] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2), _mm256_set1_epi8(-128)); + qx[3] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3), _mm256_set1_epi8(-128)); +#else + auto s0 = _mm256_sign_epi8(qx[0], qx[0]); + auto s1 = _mm256_sign_epi8(qx[1], qx[1]); + auto s2 = _mm256_sign_epi8(qx[2], qx[2]); + auto s3 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib); + auto y = MM256_SET_M128I(y128, y128); +#ifdef HAVE_FANCY_SIMD + isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[0], _mm256_shuffle_epi32(y, 0x00)); + isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[1], _mm256_shuffle_epi32(y, 0x55)); + isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa)); + isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))); + auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]))); + auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))); + auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi1, sumi2)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi3, sumi4)); +#endif + } + } +#ifdef HAVE_FANCY_SIMD + auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-128.f)); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))); + acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]); +#ifdef HAVE_FANCY_SIMD + auto bsums = (const float *)q8.y[iy][ibl].bsums; + acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]); +#endif + isum[iy] = _mm256_setzero_si256(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = _mm256_setzero_ps(); + } + } +} + template <int nrc_y> static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -5976,6 +6050,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[7] = mul_mat_q6_k_r4_q8_k<8>; expected_typeB = GGML_TYPE_Q8_K; break; + case GGML_TYPE_Q8_K_R8: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_q8_k_r8_q8_k<1>; + mm.funcs[1] = mul_mat_q8_k_r8_q8_k<2>; + mm.funcs[2] = mul_mat_q8_k_r8_q8_k<3>; + mm.funcs[3] = mul_mat_q8_k_r8_q8_k<4>; + mm.funcs[4] = mul_mat_q8_k_r8_q8_k<5>; + mm.funcs[5] = mul_mat_q8_k_r8_q8_k<6>; + mm.funcs[6] = mul_mat_q8_k_r8_q8_k<7>; + mm.funcs[7] = mul_mat_q8_k_r8_q8_k<8>; + expected_typeB = GGML_TYPE_Q8_KR8; + break; case GGML_TYPE_IQ4_K_R4: assert (ne00 % QK_K == 0); mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>; @@ -9158,6 +9244,55 @@ void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf } } +template <int nrc_y> +void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8<nrc_y, block_q8_K> q8(info); + int nbl = n / QK_K; + float32x4_t acc[2*nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4l = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+0)); + auto d4h = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+4)); + int32x4_t isum[2*nrc_y] = {}; + for (int ib = 0; ib < QK_K/16; ++ib) { + auto q1 = vld1q_u8_x4(iq8[ibl].qs + 128*ib + 0); + auto q2 = vld1q_u8_x4(iq8[ibl].qs + 128*ib + 64); + for (int k = 0; k < 4; ++k) { + q1.val[k] = veorq_u8(q1.val[k], vdupq_n_u8(0x80)); + q2.val[k] = veorq_u8(q2.val[k], vdupq_n_u8(0x80)); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+16*ib); + isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[0], y, 0); + isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[1], y, 0); + isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[2], y, 1); + isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[3], y, 1); + isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[0], y, 2); + isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[1], y, 2); + isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[2], y, 3); + isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[3], y, 3); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto d8 = vdupq_n_f32(q8.scale(iy, ibl)); + const float * bsum = (const float *)q8.y[iy][ibl].bsums; + auto m8 = vdupq_n_f32(-128.f*bsum[0]); + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[2*iy+0])); + acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[2*iy+1])); + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], d4l, m8); + acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], d4l, m8); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix+0, iy, acc[2*iy+0]); + info.store(ix+4, iy, acc[2*iy+1]); + acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f); + } + } +} + void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8<1, block_q8_0_x4> q8(info); @@ -9575,6 +9710,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS(m, mul_mat_q6_k_r4_q8_k); expected_Btype = GGML_TYPE_Q8_K; break; + case GGML_TYPE_Q8_K_R8: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k); + expected_Btype = GGML_TYPE_Q8_KR8; + break; case GGML_TYPE_IQ4_K_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k); expected_Btype = GGML_TYPE_Q8_K; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 438a277e..de8c0d99 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -2469,7 +2469,7 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe return nrows * nblock * sizeof(block_iq6_k); } -template <bool is_K32> +template <int q8_type> void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2505,7 +2505,7 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { __m256i i1 = _mm256_cvtps_epi32(v1); __m256i i2 = _mm256_cvtps_epi32(v2); __m256i i3 = _mm256_cvtps_epi32(v3); - if constexpr (is_K32) { + if constexpr (q8_type > 0) { int bsum = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); auto bs = (float *)y[i].bsums; bs[ib] = d*bsum; @@ -2520,6 +2520,12 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { _mm256_storeu_si256((__m256i *)q8, i0); q8 += 32; } + if constexpr (q8_type == 2) { + auto bs = (float *)y[i].bsums; + float sum = 0; + for (int ib = 0; ib < QK_K/32; ++ib) sum += bs[ib]; + bs[0] = sum; + } } #else for (int i = 0; i < nb; i++) { @@ -2545,15 +2551,20 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { int v = nearest_int(iscale*x[j]); y[i].qs[j] = MIN(127, v); } - if constexpr (is_K32) { + if constexpr (q8_type > 0) { auto bs = (float *)y[i].bsums; float d = 1/iscale; + float sum = 0; for (int j = 0; j < QK_K/32; ++j) { int sum = 0; for (int ii = 0; ii < 32; ++ii) { sum += y[i].qs[j*32 + ii]; } bs[j] = d*sum; + sum += bs[j]; + } + if constexpr (q8_type == 2) { + bs[0] = sum; } } else { for (int j = 0; j < QK_K/16; ++j) { @@ -2572,11 +2583,15 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { } void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) { - iqk_quantize_row_q8_K_T<false>(x, vy, k); + iqk_quantize_row_q8_K_T<0>(x, vy, k); } void quantize_row_q8_K32(const float * x, void * vy, int64_t k) { - iqk_quantize_row_q8_K_T<true>(x, vy, k); + iqk_quantize_row_q8_K_T<1>(x, vy, k); +} + +void quantize_row_q8_KR8(const float * x, void * vy, int64_t k) { + iqk_quantize_row_q8_K_T<2>(x, vy, k); } namespace { @@ -4666,3 +4681,81 @@ void vec_dot_iq4_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t GGML_UNUSED(by); } +// +// ========================================= q8_k_r8 +// + +void quantize_row_q8_k_r8_ref(const float * x, block_q8_k_r8 * y, int64_t k) { + quantize_q8_k_r8(x, (void *)y, 8, k/8, nullptr); +} + +void quantize_row_q8_k_r8(const float * x, void * y, int64_t k) { + quantize_q8_k_r8(x, y, 8, k/8, nullptr); +} + +static void repack_q8_k(int nrows, int n_per_row, const block_q8_K * x, block_q8_k_r8 * y) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_q8_K * x8[8]; + for (int row = 0; row < nrows; row += 8) { + for (int k = 0; k < 8; ++k) x8[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 8; ++k) { + y[ibl].d[k] = GGML_FP32_TO_FP16(x8[k][ibl].d); + for (int ib = 0; ib < QK_K/4; ++ib) { + for (int i = 0; i < 4; ++i) y[ibl].qs[32*ib + 4*k + i] = x8[k][ibl].qs[4*ib+i]; + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q8_k_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_K, n_per_row); + auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_K_R8, n_per_row); + std::vector<char> qtmp(8*row_size_0); + for (int row = 0; row < nrows; row += 8) { + quantize_row_q8_K32(src, (void *)qtmp.data(), 8*n_per_row); + repack_q8_k(8, n_per_row, (const block_q8_K *)qtmp.data(), (block_q8_k_r8 *)qcur); + qcur += 8*row_size_1; + src += 8*n_per_row; + } + return nrows*row_size_1; +} + +void dequantize_row_q8_k_r8(const block_q8_k_r8 * x, float * y, int64_t k) { + auto n_per_row = k/8; + float * y8[8]; + for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 8; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/4; ++ib) { + for (int i = 0; i < 4; ++i) { + y8[k][QK_K*ibl+4*ib+i] = d * x[ibl].qs[32*ib+4*k+i]; + } + } + } + } +} + +void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_K_R8, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index c5702d73..753bbdb5 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -145,11 +145,18 @@ size_t quantize_iq4_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT d void dequantize_row_iq4_k_r4(const block_iq4_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq4_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_q8_k_r8_ref(const float * GGML_RESTRICT x, block_q8_k_r8 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_k_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q8_k_r8(const block_q8_k_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q8_k_r8_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); #ifdef __cplusplus } |