diff options
Diffstat (limited to 'ggml/src')
-rw-r--r-- | ggml/src/iqk/iqk_gemm_legacy_quants.cpp | 32 |
1 files changed, 13 insertions, 19 deletions
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp index 32ce78f2..76f8db09 100644 --- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp @@ -188,20 +188,14 @@ struct ScaleHelperQ8_2 { inline __m256 prepare4(__m256 other_scales, const Q * y) { return _mm256_mul_ps(other_scales, prepare4<Q>(y)); } - template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const { - float d = GGML_BF16_TO_FP32(y->d); + template <typename Q> static inline std::pair<float, float> prepare1(const Q * y) { + float d = GGML_BF16_TO_FP32(ggml_bf16_t{y->d}); int16_t m = *(const int16_t *)&y->s; return std::make_pair(d, d*m); } - template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const { - float d = GGML_BF16_TO_FP32(y->d); - int16_t m = *(const int16_t *)&y->s; - return std::make_pair(dm.first*d, dm.second*d*m); - } - std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_2 * y) const { - ggml_bf16_t dy; dy.bits = y->d; int16_t s = *(const int16_t *)&y->s; - float d = GGML_BF16_TO_FP32(dy); - return std::make_pair(dm.first*d, dm.second*d*s); + static inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const block_q8_2 * y) { + auto d = prepare1(y); + return std::make_pair(dm.first*d.first, dm.second*d.second); } }; @@ -1484,14 +1478,14 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn } } if (4*(nb/4) < nb) { - auto qy = (const block_q8_1 *)q8.y[0]; + auto qy = (const block_q8_2 *)q8.y[0]; for (int ib = 4*(nb/4); ib < nb; ++ib) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d)); auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx); - ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); + auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8)); acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]); - acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[1]); + acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(m8), acc[1]); } } info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0])); @@ -1535,12 +1529,12 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127)); } for (int iy = 0; iy < nrc_y; ++iy) { - auto qy = (const block_q8_1 *)q8.y[iy]; + auto qy = (const block_q8_2 *)q8.y[iy]; auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs); - ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; - auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); + auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib); + auto dy = _mm512_set1_ps(d8); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(m8), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { |