diff options
Diffstat (limited to 'iqk-quantize.cpp')
-rw-r--r-- | iqk-quantize.cpp | 37 |
1 files changed, 23 insertions, 14 deletions
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp index b8d91bcf..03d73ff0 100644 --- a/iqk-quantize.cpp +++ b/iqk-quantize.cpp @@ -118,17 +118,29 @@ uint16_t IQ1BNQuantizer::quantize_one_block_1bn(const IQ1BNData& iq1bn, const fl void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) { + static const int k_nb[8] = {1, 3, 9, 27, 81, 243, 729, 2187}; (void)imatrix; const int nblock = n_per_row/QK_IQ1BN; - const auto& iq1bn = get_iq1bn_data(); - for (int ib = 0; ib < nblock; ++ib) { std::memset(&y[ib], 0, sizeof(block_iq1_bn)); - auto xb = src + QK_IQ1BN*ib; - y[ib].extra = quantize_one_block_1bn(iq1bn, xb, L, y[ib].ql, y[ib].qh); + auto xb = src + ib*QK_IQ1BN; + for (int i = 0; i < QK_IQ1BN/8; ++i) { + int idx = 0; + for (int j = 0; j < 8; ++j) { + float v = xb[8*i + j]; + int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2; + idx += k_nb[j]*q; + } + idx = (8192*idx + 6560)/6561; + y[ib].ql[i] = idx & 255; + y[ib].qh[i%4] |= ((idx >> 8) & 0xf) << 4*(i/4); + y[ib].extra |= (idx >> 12) << i; + + } } + } void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix) { @@ -182,21 +194,18 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) { assert(k%QK_IQ1BN == 0); int nblock = k / QK_IQ1BN; - uint32_t aux32[2]; - const int8_t * aux8 = (const int8_t *)aux32; + static const int k_mult[8] = {17496, 5832, 1944, 648, 216, 72, 24, 8}; + for (int i = 0; i < nblock; ++i) { uint8_t extra = x[i].extra; auto qh = x[i].qh; auto ql = x[i].ql; for (int k = 0; k < QK_IQ1BN/8; ++k) { - uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00); - uint16_t val = extra & 1 ? 0xaaaa - iq1bn_grid_u16[idx] : iq1bn_grid_u16[idx]; - aux32[0] = val | (val << 14); - aux32[1] = (aux32[0] >> 4) & 0x03030303; - aux32[0] &= 0x03030303; - for (int j = 0; j < 8; ++j) y[j] = aux8[j] - 1; - y += 8; - extra >>= 1; + uint16_t idx = ql[k] | ((qh[k%4] << (8 - 4*(k/4))) & 0x0f00) | ((extra << (12 - k)) & 4096); + for (int j = 0; j < 8; ++j) { + int v = (idx*k_mult[j] & 0xffff)*3 >> 16; + *y++ = v - 1; + } } } } |