author | Kawrakow <iwankawrakow@gmail.com> | 2025-06-18 15:30:56 +0300
committer | GitHub <noreply@github.com> | 2025-06-18 15:30:56 +0300
commit | c410cc72bbfcbdef9ce552b425ab7abbeb250dff (patch)
tree | a89b0a94dd7cdf99aef9ee3d0f1abbd48d7a3c3e
parent | dc96820ddb45c639ea4e149e4bbfcb0b67fbcc2b (diff)
Much faster CPU prompt processing (part 3) (#534)
* Repack q4_0 and q8_0 to q8_0_R8
q8_0 is fine, but I observe a very significant PPL increase
for q4_0. Best guess: precision loss in the 32-bit <-> 16-bit scale conversions.
* Change q8_2_x4 to store int16_t sums (see the sketch after the commit message)
With that, q4_0 now works.
I need to check all quants that use q8_2_x4!
* q5_0, and use a dequantizing template
* q6_0
129 t/s -> 296 t/s. q6_0_r4 is at 244 t/s.
* iq4_nl
137 t/s -> 293 t/s. iq4_nl_r4 is at 251 t/s.
* q4_1: 135 t/s -> 262 t/s
* q5_1: 125 t/s -> 253 t/s
* iq4_xs
178 t/s -> 363 t/s. iq4_xs_r4 is at 275 t/s.
* q2_K
202 t/s -> 364 t/s. q2_k_r4 is at 247 t/s.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
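To make the q8_2_x4 change concrete: the second half of each scale block previously held d*sum rounded to bf16, and now holds the raw int16_t block sums, so the product d*sum is formed in fp32 at the point of use (see ScaleHelperQ8_2 and quantize_row_q8_1_x4_T in the diff below). The following scalar sketch illustrates why that matters; bf16 is approximated here by truncation to the upper 16 bits (a stand-in for ggml's round-to-nearest conversion), and the numbers are arbitrary.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncating stand-in for an fp32 -> bf16 -> fp32 round trip (~7 mantissa bits survive).
float bf16_round_trip(float x) {
    uint32_t u;
    std::memcpy(&u, &x, sizeof u);
    u &= 0xffff0000u;                        // keep sign, exponent, top 7 mantissa bits
    float y;
    std::memcpy(&y, &u, sizeof y);
    return y;
}

int main() {
    float   d   = 0.0123f;                   // hypothetical block scale
    int16_t sum = 1234;                      // sum of a 32-quant block of int8 values; always fits in int16_t
    float before = bf16_round_trip(d * (float)sum);  // old layout: d*sum stored as bf16
    float after  = d * (float)sum;                   // new layout: exact int16_t sum, product formed in fp32
    std::printf("bf16-stored product: %.6f   fp32 product: %.6f   error: %.3g\n",
                before, after, before - after);
    return 0;
}
```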
-rw-r--r-- | ggml/src/iqk/iqk_gemm_kquants.cpp | 187
-rw-r--r-- | ggml/src/iqk/iqk_gemm_legacy_quants.cpp | 179
-rw-r--r-- | ggml/src/iqk/iqk_gemm_legacy_quants.h | 2
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 29
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 12
5 files changed, 366 insertions, 43 deletions
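Most of the added code below converts a quantized row type on the fly into a row-interleaved Q8 layout (q8_0_r8, q8_1_r8, or q8_k_r8), so the GEMM kernels can fetch the same 4-byte slice of 8 consecutive rows with a single contiguous 32-byte load. Here is a scalar sketch of that interleaving, mirroring the qs[8*l + k] stores in iqk_convert_q80_q80_r8 below; the struct definitions are simplified stand-ins, not the exact ggml block types.

```cpp
#include <cstdint>
#include <cstring>

// Simplified stand-ins for the ggml block types (one 32-quant block per row).
struct q8_0_block    { uint16_t d;    int8_t qs[32];     };  // per-row: fp16 scale + 32 int8 quants
struct q8_0_r8_block { uint16_t d[8]; int8_t qs[8 * 32]; };  // the same block position from 8 rows, interleaved

// Scalar equivalent of the interleaving in iqk_convert_q80_q80_r8:
// 4-byte word w of row k lands at word index 8*w + k (w = 0..3) or 8*(w-4) + k + 32 (w = 4..7).
void repack_8_rows(const q8_0_block * const rows[8], q8_0_r8_block * out) {
    uint32_t block[8];
    uint32_t * qs = reinterpret_cast<uint32_t *>(out->qs);
    for (int k = 0; k < 8; ++k) {                 // k = row within the group of 8
        out->d[k] = rows[k]->d;
        std::memcpy(block, rows[k]->qs, 32);      // 32 quants = 8 uint32 words
        for (int l = 0; l < 4; ++l) {
            qs[8*l + k]      = block[l];          // words 0..3 of row k
            qs[8*l + k + 32] = block[l + 4];      // words 4..7 of row k
        }
    }
}
```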
diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp index 43eff43c..b46077f8 100644 --- a/ggml/src/iqk/iqk_gemm_kquants.cpp +++ b/ggml/src/iqk/iqk_gemm_kquants.cpp @@ -810,10 +810,11 @@ static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const Data auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d))); auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16)); _mm256_storeu_ps(d8 + 8*iy, dy); - auto m4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4))); - auto m4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4))); - auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(m4_2, m4_1), 16)); - accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]); + auto m4_1 = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4))); + auto m4_2 = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4))); + auto myi = MM256_SET_M128I(m4_2, m4_1); + auto my = _mm256_mul_ps(dy, _mm256_cvtepi32_ps(myi)); + accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]); } auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp)))); @@ -2017,6 +2018,91 @@ typedef struct { int8_t qs[8*QK8_1]; } block_q8_1_r8; +void iqk_convert_q2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_q2_K * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + float f_values[QK_K]; + uint32_t block[8]; + + __m256i xv[4]; + + auto ml = _mm256_set1_epi8(0x03); + auto sign_bit = _mm256_set1_ps(-0.0f); + auto perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_q2_K *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + auto vd = _mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].d)); + auto vm = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].dmin)), _mm256_set1_ps(-1.f)); + auto block_max = _mm256_setzero_ps(); + for (int i128 = 0; i128 < 2; ++i128) { + auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128); + xv[0] = _mm256_and_si256(bits, ml); + xv[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml); + xv[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml); + xv[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml); + for (int l = 0; l < 4; ++l) { + auto q1 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xv[l])); + auto q2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv[l], 1)); + q1 = _mm256_mullo_epi16(q1, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 0] & 0xf)); + q2 = _mm256_mullo_epi16(q2, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 1] & 0xf)); + auto m1 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 0] >> 4)); + auto m2 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 1] >> 4)); + auto v0 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q1))), vd, m1); + auto v1 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q1, 1))), vd, m1); + auto v2 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q2))), vd, m2); + auto v3 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q2, 1))), vd, m2); + auto max = 
_mm256_max_ps(_mm256_max_ps(_mm256_andnot_ps(sign_bit, v0), _mm256_andnot_ps(sign_bit, v1)), + _mm256_max_ps(_mm256_andnot_ps(sign_bit, v2), _mm256_andnot_ps(sign_bit, v3))); + block_max = _mm256_max_ps(block_max, max); + _mm256_storeu_ps(f_values + 128*i128 + 32*l + 0, v0); + _mm256_storeu_ps(f_values + 128*i128 + 32*l + 8, v1); + _mm256_storeu_ps(f_values + 128*i128 + 32*l + 16, v2); + _mm256_storeu_ps(f_values + 128*i128 + 32*l + 24, v3); + } + } + auto max4 = _mm_max_ps(_mm256_extractf128_ps(block_max, 1), _mm256_castps256_ps128(block_max)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + float d = _mm_cvtss_f32(max4/127.f); + auto id = _mm256_set1_ps(d != 0.0f ? 1/d : 0.0f); + y[i].d[k] = GGML_FP32_TO_FP16(d); + for (int ib32 = 0; ib32 < 8; ++ib32) { + auto v0 = _mm256_loadu_ps(f_values + 32*ib32 + 0); + auto v1 = _mm256_loadu_ps(f_values + 32*ib32 + 8); + auto v2 = _mm256_loadu_ps(f_values + 32*ib32 + 16); + auto v3 = _mm256_loadu_ps(f_values + 32*ib32 + 24); + auto i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v0, id), _MM_ROUND_NEAREST)); + auto i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v1, id), _MM_ROUND_NEAREST)); + auto i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v2, id), _MM_ROUND_NEAREST)); + auto i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v3, id), _MM_ROUND_NEAREST)); + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + i0 = _mm256_packs_epi16(i0, i2); + i0 = _mm256_permutevar8x32_epi32(i0, perm); + + _mm256_storeu_si256((__m256i *)block, i0); + auto q8 = (uint32_t *)y[i].qs + 64*ib32; + for (int l = 0; l < 4; ++l) { + q8[8*l + k + 0] = block[l + 0]; + q8[8*l + k + 32] = block[l + 4]; + } + } + } + } + y += nb; + } +} + void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { GGML_ASSERT(n%QK_K == 0); GGML_ASSERT(nrc_x%8 == 0); @@ -2429,6 +2515,97 @@ void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int } } +inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) { + auto max_i16 = _mm256_setzero_si256(); + __m256i qs[16]; + for (int ib32 = 0; ib32 < 8; ++ib32) { + qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32])); + qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1)); + qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0])); + qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1])); + max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0])); + max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1])); + } + auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1))); + auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1)); + auto max4 = _mm_cvtepi32_ps(imax4); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + bool needs_scaling = true; + float dnew = _mm_cvtss_f32(max4) * d0; + if (dnew < 1.f) { + dnew = 1.f; needs_scaling = false; + } + auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 
1/dnew : 0.f); + for (int ib32 = 0; ib32 < 8; ++ib32) { + if (needs_scaling) { + auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0])); + auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1)); + auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1])); + auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1)); + i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST)); + i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST)); + i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST)); + i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST)); + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + i0 = _mm256_packs_epi16(i0, i2); + i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + _mm256_storeu_si256((__m256i *)block, i0); + } else { + // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 + auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]); + auto i0_l = _mm256_castsi256_si128(i0); + auto i0_h = _mm256_extracti128_si256(i0, 1); + _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h)); + _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h)); + } + auto qs = (uint32_t *)q8_k + 64*ib32; + for (int l = 0; l < 8; ++l) { + qs[8*l + k] = block[l]; + } + } + return dnew; +} + +// TODO: move this to iqk_gemm_iquants +void iqk_convert_iq4_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq4_xs * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); + auto values = MM256_SET_M128I(values128, values128); + + int16_t ls[16]; + float dnew[8]; + uint32_t block[8]; + __m256i xv[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_iq4_xs *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + for (int ib32 = 0; ib32 < 8; ++ib32) { + ls[2*ib32+0] = ls[2*ib32+1] = (((x8[k][i].scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((x8[k][i].scales_h >> 2*ib32) & 3) << 4)) - 32; + auto bits = _mm_loadu_si128((const __m128i *)x8[k][i].qs + ib32); + xv[ib32] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(bits, 4), bits), _mm256_set1_epi8(0xf)); + xv[ib32] = _mm256_shuffle_epi8(values, xv[ib32]); + } + dnew[k] = d * convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs); + } + _mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(_mm256_loadu_ps(dnew), _MM_ROUND_NEAREST)); + } + y += nb; + } +} + } // namespace @@ -2516,10 +2693,12 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_ bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) { switch (ggml_type(type)) { + case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break; case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break; case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break; case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break; case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, 
nrc_x); break; + case GGML_TYPE_IQ4_XS: iqk_convert_iq4_xs_q8_k_r8(n, vx, bx, vy, nrc_x); break; default: return false; } return true; diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp index 17d2dad3..32ce78f2 100644 --- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp @@ -172,27 +172,36 @@ struct ScaleHelperQ8_1 { } }; +inline __m256 convert_scales(const uint16_t * scales) { + auto aux_d = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales)), 16)); + auto aux_m = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(scales+4)))); + return _mm256_set_m128(_mm_mul_ps(aux_d, aux_m), aux_d); +} + struct ScaleHelperQ8_2 { template <typename Q> inline __m256 prepare4(const Q * y) { const block_q8_2_x4 * y4 = (const block_q8_2_x4 *)y; - auto aux = _mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y4->d)); - return _mm256_castsi256_ps(_mm256_slli_epi32(aux, 16)); + return convert_scales((const uint16_t *)y4->d); } template <typename Q> inline __m256 prepare4(__m256 other_scales, const Q * y) { return _mm256_mul_ps(other_scales, prepare4<Q>(y)); } template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const { - return std::make_pair(GGML_BF16_TO_FP32(y->d), GGML_BF16_TO_FP32(y->m)); + float d = GGML_BF16_TO_FP32(y->d); + int16_t m = *(const int16_t *)&y->s; + return std::make_pair(d, d*m); } template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const { - ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s; - return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s)); + float d = GGML_BF16_TO_FP32(y->d); + int16_t m = *(const int16_t *)&y->s; + return std::make_pair(dm.first*d, dm.second*d*m); } std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_2 * y) const { - ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s; - return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s)); + ggml_bf16_t dy; dy.bits = y->d; int16_t s = *(const int16_t *)&y->s; + float d = GGML_BF16_TO_FP32(dy); + return std::make_pair(dm.first*d, dm.second*d*s); } }; @@ -542,6 +551,14 @@ struct IQ4_NL_Dequantizer { } }; +struct IQ4_NL0_Dequantizer { + Dequantizer4bit b4; + const __m256i values = load_iq4k_values_256(); + inline __m256i dequant(const block_iq4_nl * x) const { + return _mm256_shuffle_epi8(values, b4.dequant(x->qs)); + } +}; + struct Q4_1_Dequantizer { Dequantizer4bit b4; inline __m256i dequant(const block_q4_1 * x) const { @@ -597,6 +614,12 @@ struct Q6_0_1_Dequantizer { return _mm256_or_si256(b4.dequant(x->qs), _mm256_and_si256(_mm256_srlv_epi64(h256, shift2), mh)); } }; +struct Q6_0_Dequantizer { + Q6_0_1_Dequantizer deq; + inline __m256i dequant(const block_q6_0 * x) const { + return _mm256_add_epi8(deq.dequant(x), _mm256_set1_epi8(-32)); + } +}; template <typename Q, typename Scales, typename Dequantizer> struct Q_Unpacker { @@ -728,8 +751,7 @@ static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const Data const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); - _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); + _mm256_storeu_ps(d8+8*iy, convert_scales((const uint16_t 
*)q8.y[iy][ib4].d)); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); @@ -893,7 +915,7 @@ static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const D auto acc1 = _mm256_setzero_ps(); auto acc2 = _mm256_setzero_ps(); for (int ib4 = 0; ib4 < nb/4; ++ib4) { - helper.vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16)); + helper.vec = convert_scales((const uint16_t *)q8.y[0][ib4].d); for (int k = 0; k < 4; ++k) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4); @@ -929,7 +951,7 @@ static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const D d4[k] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); } for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); + auto scales = convert_scales((const uint16_t *)q8.y[iy][ib4].d); _mm256_storeu_ps(d8 + 8*iy, scales); auto m4 = _mm256_extractf128_ps(scales, 1); auto m8 = _mm256_set_m128(m4, m4); @@ -1020,8 +1042,7 @@ static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn const block_iq4_nl_r8 * iq4h = (const block_iq4_nl_r8 *)((const char *)vx + (ix+8)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); - _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); + _mm256_storeu_ps(d8+8*iy, convert_scales((const uint16_t *)q8.y[iy][ib4].d)); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); @@ -1108,7 +1129,7 @@ static void mul_mat_q5_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const D const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); + auto scales = convert_scales((const uint16_t *)q8.y[iy][ib4].d); _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales)); } for (int k = 0; k < 4; ++k) { @@ -1189,7 +1210,7 @@ static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataIn const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16))); + _mm256_storeu_ps(d8+8*iy, convert_scales((const uint16_t *)q8.y[iy][ib4].d)); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]); @@ -1278,8 +1299,8 @@ static void mul_mat_q6_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const D const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); - _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale)); + auto scales = convert_scales((const uint16_t *)q8.y[iy][ib4].d); + _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale)); } for (int k = 0; k < 4; ++k) { auto scales = 
prepare(iq6[4*ib4+k]); @@ -1358,7 +1379,7 @@ static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataIn const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); + auto scales = convert_scales((const uint16_t *)q8.y[iy][ib4].d); _mm256_storeu_ps(d8 + 8*iy, scales); } for (int k = 0; k < 4; ++k) { @@ -1453,8 +1474,7 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn for (int ix = 0; ix < nrc_x; ix += 8) { const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { - auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16); - _mm256_storeu_ps(d8, _mm256_castsi256_ps(aux)); + _mm256_storeu_ps(d8, convert_scales((const uint16_t *)q8.y[0][ib4].d)); for (int k = 0; k < 4; ++k) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx); @@ -1486,8 +1506,7 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); - _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); + _mm256_storeu_ps(d8+8*iy, convert_scales((const uint16_t *)q8.y[iy][ib4].d)); } for (int k = 0; k < 4; ++k) { auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d)); @@ -1655,7 +1674,8 @@ static void mul_mat_q8_1_r8_q8_2(int n, const void * vx, size_t bx, const DataIn for (int iy = 0; iy < nrc_y; ++iy) { auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][i4].d)), 16)); _mm_storeu_ps(d8 + 4*iy + 0, scales); - auto bsums4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][i4].d+4))), 16)); + auto bsums4 = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][i4].d+4)))); + bsums4 = _mm_mul_ps(bsums4, scales); auto bsums = _mm256_set_m128(bsums4, bsums4); acc[iy] = _mm256_fmadd_ps(mx[0], _mm256_shuffle_ps(bsums, bsums, 0x00), acc[iy]); acc[iy] = _mm256_fmadd_ps(mx[1], _mm256_shuffle_ps(bsums, bsums, 0x55), acc[iy]); @@ -1690,6 +1710,105 @@ static void mul_mat_q8_1_r8_q8_2(int n, const void * vx, size_t bx, const DataIn } } +void iqk_convert_q80_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + static_assert(QK4_0 == QK8_0); + GGML_ASSERT(n%QK4_0 == 0); + GGML_ASSERT(nrc_x%8 == 0); + + const int nb = n/QK4_0; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_q8_0 * x8[8]; + + uint32_t block[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + + for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)((const char *)vx + (ix + k)*bx); + + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + y[i].d[k] = x8[k][i].d; + _mm256_storeu_si256((__m256i *)block, _mm256_loadu_si256((const __m256i *)x8[k][i].qs)); + auto qs = (uint32_t *)y[i].qs; + for (int l = 0; l < 4; ++l) { + qs[8*l + k + 0] = block[l + 0]; + qs[8*l + k + 32] = block[l + 4]; + } + } + } + y += nb; + 
} +} + +template <typename Block, typename Dequantizer> +void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK4_0 == 0); + GGML_ASSERT(nrc_x%8 == 0); + + const int nb = n/QK8_0; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const Block * x8[8]; + + uint32_t block[8]; + + Dequantizer deq; + + for (int ix = 0; ix < nrc_x; ix += 8) { + + for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx); + + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + y[i].d[k] = x8[k][i].d; + _mm256_storeu_si256((__m256i *)block, deq.dequant(x8[k] + i)); + auto qs = (uint32_t *)y[i].qs; + for (int l = 0; l < 4; ++l) { + qs[8*l + k + 0] = block[l + 0]; + qs[8*l + k + 32] = block[l + 4]; + } + } + } + y += nb; + } +} + +template <typename Block, typename Dequantizer> +void iqk_convert_qX_1_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK8_0 == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK8_0; + + const Block * x8[8]; + + block_q8_1_r8 * y = (block_q8_1_r8 *)vy; + + uint32_t block[8]; + + Dequantizer deq; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + y[i].d[k+0] = x8[k][i].d; + y[i].d[k+8] = x8[k][i].m; + _mm256_storeu_si256((__m256i *)block, deq.dequant(x8[k]+i)); + auto qs = (uint32_t *)y[i].qs; + for (int l = 0; l < 4; ++l) { + qs[8*l + k + 0] = block[l + 0]; + qs[8*l + k + 32] = block[l + 4]; + } + } + } + y += nb; + } +} + template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) { if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> || std::is_same_v<Dequantizer, Q8_0_Unpacker>) { @@ -1713,6 +1832,20 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX } // namespace +bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) { + switch (type) { + case GGML_TYPE_Q4_0 : iqk_convert_qX_q80_r8<block_q4_0, Q4_0_Dequantizer>(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_Q4_1 : iqk_convert_qX_1_q8_1_r8<block_q4_1, Q4_1_Dequantizer>(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_Q5_0 : iqk_convert_qX_q80_r8<block_q5_0, Q5_0_Dequantizer>(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8<block_q5_1, Q5_1_Dequantizer<block_q5_1>>(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8<block_q6_0, Q6_0_Dequantizer>(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL0_Dequantizer>(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_Q8_0 : iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break; + default: return false; + } + return true; +} + bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) { if (ne00%QK8_0 != 0) return false; diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.h b/ggml/src/iqk/iqk_gemm_legacy_quants.h index a472d9bb..179e806a 100644 --- a/ggml/src/iqk/iqk_gemm_legacy_quants.h +++ b/ggml/src/iqk/iqk_gemm_legacy_quants.h @@ -11,4 +11,6 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu void iqk_gemm_legacy_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step); +bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, 
size_t bx, void * vy, int nrc_x); + #endif diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 81b5841d..6925e6a6 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -243,9 +243,11 @@ struct MulMat { case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ2_S : return nrc_y >= 16 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ4_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ1_M : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_Q2_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_Q3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; @@ -258,6 +260,13 @@ struct MulMat { case GGML_TYPE_IQ5_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ6_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_Q4_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_Q4_1 : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; + case GGML_TYPE_Q5_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_Q5_1 : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; + case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; default: break; } #else @@ -356,12 +365,12 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, //case GGML_TYPE_BF16: //case GGML_TYPE_BF16_R16: // return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs); - //case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: - //case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_XS: //case GGML_TYPE_Q2_K_R4: //case GGML_TYPE_Q3_K_R4: //case GGML_TYPE_Q4_K_R4: @@ -403,19 +412,19 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: return iqk_dequantize_ktquants(typeA, n, vx, bx, vy, stride_y, nrc_x); - //case GGML_TYPE_Q4_0: - //case GGML_TYPE_Q4_1: - //case GGML_TYPE_Q5_0: - //case GGML_TYPE_Q5_1: - //case GGML_TYPE_Q6_0: - //case GGML_TYPE_Q8_0: - //case GGML_TYPE_IQ4_NL: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: //case GGML_TYPE_Q4_0_R8: //case GGML_TYPE_Q5_0_R4: //case GGML_TYPE_Q6_0_R4: //case GGML_TYPE_Q8_0_R8: //case GGML_TYPE_IQ4_NL_R4: - // return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16); + return iqk_convert_legacy_quants_q8_r8(typeA, n, vx, bx, vy, nrc_x); case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: //case GGML_TYPE_IQ1_S_R4: diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 9261d02e..abd4be61 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -875,14 +875,12 @@ void quantize_row_q8_1_x4_T(const float * x, Block * y, int64_t k) { y[i].d = GGML_FP32_TO_FP16(d); } } else { + auto t = GGML_FP32_TO_BF16(d); + d = 
ggml_bf16_to_fp32(t); if (i < nb4) { - auto t = GGML_FP32_TO_BF16(d); y4[i4].d[ir] = t.bits; - d = ggml_bf16_to_fp32(t); } else { - auto t = GGML_FP32_TO_BF16(d); y[i].d = t.bits; - d = ggml_bf16_to_fp32(t); } } const float id = d > 0 ? 1/d : 0.f; @@ -916,9 +914,11 @@ void quantize_row_q8_1_x4_T(const float * x, Block * y, int64_t k) { } } else { if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_BF16(d * isum).bits; + auto i16 = (int16_t *)y4[i4].d; + i16[ir+4] = isum; } else { - y[i].s = GGML_FP32_TO_BF16(d * isum).bits; + auto i16 = (int16_t *)&y[i].s; + i16[0] = isum; } } |
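Putting the two ends of the q8_2_x4 change together: quantize_row_q8_1_x4_T (above) now stores the raw int16_t sub-block sums, and ScaleHelperQ8_2 / convert_scales on the consumer side rebuild d and d*sum in fp32. A scalar sketch of that contract follows, using an illustrative stand-in for the scale area rather than the exact block_q8_2_x4 definition.

```cpp
#include <cstdint>
#include <cstring>
#include <utility>

// Illustrative stand-in for the scale area of a q8_2_x4 group of 4 sub-blocks:
// d[0..3] hold bf16 scales, d[4..7] now hold the raw int16_t sum of each sub-block.
struct q8_2_x4_scales { uint16_t d[8]; };

float bf16_bits_to_f32(uint16_t bits) {
    uint32_t u = uint32_t(bits) << 16;        // same trick as the _mm_slli_epi32(..., 16) in convert_scales
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Scalar equivalent of the new ScaleHelperQ8_2::prepare1: returns {d, d*sum} for sub-block ir.
std::pair<float, float> scales_for_subblock(const q8_2_x4_scales & s, int ir) {
    float   d = bf16_bits_to_f32(s.d[ir]);
    int16_t sum;
    std::memcpy(&sum, &s.d[ir + 4], sizeof sum);   // reinterpret the second half as int16_t
    return {d, d * (float)sum};
}
```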