diff options
-rw-r--r-- | iqk-quantize.cpp | 45 |
1 files changed, 27 insertions, 18 deletions
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp index f5e357a1..b2100ac4 100644 --- a/iqk-quantize.cpp +++ b/iqk-quantize.cpp @@ -182,6 +182,15 @@ void dequantize_row_iq2_bn(const block_iq2_bn * x, float * y, int64_t k) { } } +namespace { +inline int8_t iq1bn_dequant(uint8_t q, int i) { + uint8_t v = IQ1BNQuantizer::k_mult[i]*q; + //int8_t vs = (v + (v << 1)) >> 8; + int8_t vs = 3*v >> 8; + return vs - 1; +} +} + void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { GGML_UNUSED(bs); @@ -204,29 +213,29 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si int sumi[8] = {}; int8_t q1[16]; - for (int i = 0; i < nblock; ++i) { - auto ql = x[i].ql; - auto extra = x[i].extra; - for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) { - for (int k = 0; k < 3; ++k) { - uint8_t q = *ql++; - for (int j = 0; j < 5; ++j) { - uint8_t v = IQ1BNQuantizer::k_mult[j]*q; - int8_t vs = 3*v >> 8; - q1[5*k+j] = vs - 1; + for (int ii = 0; ii < nblock; ii += 32) { + int16_t sum16[8] = {}; + int nb = std::min(ii + 32, nblock); + for (int i = ii; i < nb; ++i) { + auto ql = x[i].ql; + auto extra = x[i].extra; + for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) { + for (int k = 0; k < 3; ++k) { + uint8_t q = *ql++; + for (int j = 0; j < 5; ++j) q1[5*k+j] = iq1bn_dequant(q, j); } + q1[15] = iq1bn_dequant(extra, i16); + // We collect 8 q8 values per block into each element of sum16 + // => 32 x 8 = 256 values in each loop over i, so this cannot overflow the int16_t range + // (q8 is in -127...127, and hence the sum is in -32512...32512 + for (int j = 0; j < 8; ++j) sum16[j] += q8[2*j+0]*q1[2*j+0] + q8[2*j+1]*q1[2*j+1]; + q8 += 16; } - uint8_t v = IQ1BNQuantizer::k_mult[i16]*extra; - int8_t vs = 3*v >> 8; - q1[15] = vs - 1; - for (int j = 0; j < 8; ++j) sumi[j] += q8[j]*q1[j]; - q8 += 8; - for (int j = 0; j < 8; ++j) sumi[j] += q8[j]*q1[8+j]; - q8 += 8; } + for (int j = 0; j < 8; ++j) sumi[j] += sum16[j]; } - *s = d8[0] * (sumi[0] + sumi[4]) + d8[1] * (sumi[1] + sumi[5]) + d8[2] * (sumi[2] + sumi[6]) + d8[3] * (sumi[3] + sumi[7]); + *s = d8[0] * (sumi[0] + sumi[1]) + d8[1] * (sumi[2] + sumi[3]) + d8[2] * (sumi[4] + sumi[5]) + d8[3] * (sumi[6] + sumi[7]); } void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { |