diff options
author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-08-19 15:33:27 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-19 15:33:27 +0300 |
commit | a73702d93b1007b2f528432c3db20c7aa5206352 (patch) | |
tree | 9837a54a824e497a0bddefbb739a54a79df7441d /ggml/src/iqk/iqk_quantize.cpp | |
parent | 5652100afcc423cf6342778cde372ca6aa54a79b (diff) |
AVX2 quantization for Q8_K (#22)
It has been there for a while, but forgot to add here.
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/iqk/iqk_quantize.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 730de8c9..b7dbe685 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1982,3 +1982,103 @@ void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t *s = sumf; } +#ifdef __AVX2__ +namespace { +inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} +inline float hmax_f32_8(__m256 x) { + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4)); + return _mm_cvtss_f32(max4); +} +} +#endif + +void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + block_q8_K * y = (block_q8_K *)vy; +#ifdef __AVX2__ + const __m256 signBit = _mm256_set1_ps(-0.0f); + const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + for (int i = 0; i < nb; i++) { + const float * xb = x + i*QK_K; + __m256 maxAbs = _mm256_setzero_ps(); + const float * xx = xb; + for (int ib = 0; ib < QK_K/8; ++ib) { + const __m256 v = _mm256_loadu_ps(xx); xx += 8; + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v)); + } + const float maxScalar = hmax_f32_8(maxAbs); + const float d = maxScalar / 127.f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + xx = xb; + int8_t * q8 = y[i].qs; + for (int ib = 0; ib < QK_K/32; ++ib) { + __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + __m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST); + v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST); + __m256i i0 = _mm256_cvtps_epi32(v0); + __m256i i1 = _mm256_cvtps_epi32(v1); + __m256i i2 = _mm256_cvtps_epi32(v2); + __m256i i3 = _mm256_cvtps_epi32(v3); + y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1)); + y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3)); + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + i0 = _mm256_packs_epi16( i0, i2 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + _mm256_storeu_si256((__m256i *)q8, i0); + q8 += 32; + } + } +#else + for (int i = 0; i < nb; i++) { + + float max = 0; + float amax = 0; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, QK_K); + x += QK_K; + continue; + } + //const float iscale = -128.f/max; + // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward + const float iscale = -127.f/max; + for (int j = 0; j < QK_K; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = MIN(127, v); + } + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = sum; + } + y[i].d = 1/iscale; + x += QK_K; + } +#endif + +} |