diff options
Diffstat (limited to 'ggml/src/iqk/iqk_quantize.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 34 |
1 files changed, 22 insertions, 12 deletions
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 2eb53d1c..9261d02e 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -2831,6 +2831,8 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { const __m256 mul = _mm256_set1_ps( id ); xx = xb; int8_t * q8 = y[i].qs; + int block_sum_i32 = 0; + float block_sum_f32 = 0; for (int ib = 0; ib < QK_K/32; ++ib) { __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; @@ -2844,13 +2846,15 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { __m256i i1 = _mm256_cvtps_epi32(v1); __m256i i2 = _mm256_cvtps_epi32(v2); __m256i i3 = _mm256_cvtps_epi32(v3); - if constexpr (q8_type > 0) { + if constexpr (q8_type == 1) { int bsum = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); auto bs = (float *)y[i].bsums; bs[ib] = d*bsum; + block_sum_f32 += bs[ib]; } else { y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1)); y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3)); + block_sum_i32 += y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]; } i0 = _mm256_packs_epi32( i0, i1 ); i2 = _mm256_packs_epi32( i2, i3 ); @@ -2859,12 +2863,17 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { _mm256_storeu_si256((__m256i *)q8, i0); q8 += 32; } - if constexpr (q8_type == 2) { - auto bs = (float *)y[i].bsums; - float sum = 0; - for (int ib = 0; ib < QK_K/32; ++ib) sum += bs[ib]; - bs[0] = sum; + if constexpr (q8_type == 1) { + y[i].sum = block_sum_f32; + } else { + y[i].sum = d*block_sum_i32; } + //if constexpr (q8_type == 2) { + // auto bs = (float *)y[i].bsums; + // float sum = 0; + // for (int ib = 0; ib < QK_K/32; ++ib) sum += bs[ib]; + // bs[0] = sum; + //} } #else for (int i = 0; i < nb; i++) { @@ -2890,9 +2899,9 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { int v = nearest_int(iscale*x[j]); y[i].qs[j] = MIN(127, v); } - if constexpr (q8_type > 0) { + float d = 1/iscale; + if constexpr (q8_type == 1) { auto bs = (float *)y[i].bsums; - float d = 1/iscale; float sum = 0; for (int j = 0; j < QK_K/32; ++j) { int sum = 0; @@ -2902,19 +2911,20 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { bs[j] = d*sum; sum += bs[j]; } - if constexpr (q8_type == 2) { - bs[0] = sum; - } + y[i].sum = sum; } else { + int tot = 0; for (int j = 0; j < QK_K/16; ++j) { int sum = 0; for (int ii = 0; ii < 16; ++ii) { sum += y[i].qs[j*16 + ii]; } y[i].bsums[j] = sum; + tot += sum; } + y[i].sum = d*tot; } - y[i].d = 1/iscale; + y[i].d = d; x += QK_K; } #endif |