diff options
-rw-r--r-- | ggml/src/ggml-quants.c | 5 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 100 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 2 |
3 files changed, 107 insertions, 0 deletions
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 41362dee..981fb54b 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -12,6 +12,7 @@ #include "ggml-impl.h" #if GGML_USE_IQK_MULMAT #include "iqk/iqk_mul_mat.h" +#include "iqk/iqk_quantize.h" #endif @@ -3770,7 +3771,11 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6 } void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { +#ifdef GGML_USE_IQK_MULMAT + iqk_quantize_row_q8_K(x, y, k); +#else quantize_row_q8_K_ref(x, y, k); +#endif } //===================================== Dot ptoducts ================================= diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 730de8c9..b7dbe685 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1982,3 +1982,103 @@ void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t *s = sumf; } +#ifdef __AVX2__ +namespace { +inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} +inline float hmax_f32_8(__m256 x) { + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4)); + return _mm_cvtss_f32(max4); +} +} +#endif + +void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + block_q8_K * y = (block_q8_K *)vy; +#ifdef __AVX2__ + const __m256 signBit = _mm256_set1_ps(-0.0f); + const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + for (int i = 0; i < nb; i++) { + const float * xb = x + i*QK_K; + __m256 maxAbs = _mm256_setzero_ps(); + const float * xx = xb; + for (int ib = 0; ib < QK_K/8; ++ib) { + const __m256 v = _mm256_loadu_ps(xx); xx += 8; + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v)); + } + const float maxScalar = hmax_f32_8(maxAbs); + const float d = maxScalar / 127.f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + xx = xb; + int8_t * q8 = y[i].qs; + for (int ib = 0; ib < QK_K/32; ++ib) { + __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + __m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST); + v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST); + __m256i i0 = _mm256_cvtps_epi32(v0); + __m256i i1 = _mm256_cvtps_epi32(v1); + __m256i i2 = _mm256_cvtps_epi32(v2); + __m256i i3 = _mm256_cvtps_epi32(v3); + y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1)); + y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3)); + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + i0 = _mm256_packs_epi16( i0, i2 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + _mm256_storeu_si256((__m256i *)q8, i0); + q8 += 32; + } + } +#else + for (int i = 0; i < nb; i++) { + + float max = 0; + float amax = 0; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, QK_K); + x += QK_K; + continue; + } + //const float iscale = -128.f/max; + // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward + const float iscale = -127.f/max; + for (int j = 0; j < QK_K; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = MIN(127, v); + } + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = sum; + } + y[i].d = 1/iscale; + x += QK_K; + } +#endif + +} diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 3c5d27a4..d7a0748f 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -49,6 +49,8 @@ size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst void dequantize_row_iq2_tn(const block_iq2_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq2_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); + #ifdef __cplusplus } #endif |