diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-10-13 13:34:30 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-13 13:34:30 +0300 |
commit | 910a13409463f7aedb0a92be013a1b9bb50f4859 (patch) | |
tree | 16e13e1fd3010549877408a0a62706b2bc5d5f0c /ggml/src/iqk/iqk_quantize.cpp | |
parent | c15de3654e0002537c8052fd6d52d879e778e88c (diff) |
IQ2_KS: 2.1875 bpw non-linear quantization (#85)
* Experimenting
* iq2k: Try make_qx_quants for the scale
Slightly better for LLaMA-3.1, Gemma-2, slightly worse for
Qwen2.5
* iq2k with make_qx_quants: adjust scale
* iq2ks: basics
* iq2_ks: CUDA works
* iq2_ks: WIP
* iq2_ks: WIP
* iq2_ks: Zen4
* iq2_ks: AVX2
* iq2_ks: scalar dot product
* iq2_ks: ARM_NEON
* iq2_ks: Metal
* iq2_ks: faster Metal
LLaMA-3.1-8B:
PP-512 = 475.22 ± 0.37 t/s
TG-128 = 45.32 ± 0.03 t/s
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/iqk/iqk_quantize.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 417 |
1 files changed, 406 insertions, 11 deletions
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 430b629f..984801be 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -30,6 +30,50 @@ inline int nearest_int(float fval) { return (i & 0x007fffff) - 0x00400000; } +float make_qx_quants(int n, int nmax, const float * x, int8_t * L, const float * qw) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (!amax) { // all zero + for (int i = 0; i < n; ++i) L[i] = 0; + return 0.f; + } + float iscale = -nmax / max; + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = std::max(-nmax, std::min(nmax-1, l)); + L[i] = l + nmax; + sumlx += qw[i]*x[i]*l; + suml2 += qw[i]*l*l; + } + float scale = suml2 ? sumlx/suml2 : 0.0f; + float best = scale * sumlx; + for (int is = -9; is <= 9; ++is) { + if (is == 0) continue; + iscale = -(nmax + 0.1f*is) / max; + sumlx = suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = std::max(-nmax, std::min(nmax-1, l)); + sumlx += qw[i]*x[i]*l; + suml2 += qw[i]*l*l; + } + if (suml2 > 0 && sumlx*sumlx > best*suml2) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + std::max(-nmax, std::min(nmax-1, l)); + } + scale = sumlx/suml2; best = scale*sumlx; + } + } + return scale; +} + struct IQ1BNQuantizer { int8_t L[QK_IQ1BN]; void quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix); @@ -507,6 +551,8 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl float scales[QK_K/kBlockSize]; float weight[kBlockSize]; float sumx[kBlockSize+1], sumw[kBlockSize+1]; + float sw[QK_K/kBlockSize]; + int8_t Ls[QK_K/kBlockSize]; std::array<std::pair<float,int>, kBlockSize> pairs; @@ -524,7 +570,7 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl uint16_t extra = 0; - float max_abs_scale = 0; + float max_abs_scale = 0, max_scale = 0; for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { const float * xb = xbl + kBlockSize*ib; @@ -534,7 +580,11 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl } else { for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; } - for (int j = 0; j < kBlockSize; ++j) pairs[j] = {xb[j], j}; + sw[ib] = 0; + for (int j = 0; j < kBlockSize; ++j) { + sw[ib] += weight[j]; + pairs[j] = {xb[j], j}; + } std::sort(pairs.begin(), pairs.end()); sumx[0] = sumw[0] = 0; for (int j = 0; j < kBlockSize; ++j) { @@ -583,21 +633,25 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl if (is_shifted) extra |= (1 << ib); float abs_scale = fabsf(scales[ib]); - max_abs_scale = MAX(max_abs_scale, abs_scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scales[ib]; + } } if (!max_abs_scale) continue; + float d = make_qx_quants(QK_K/kBlockSize, 8, scales, Ls, sw); + if (!d) continue; - float d = max_abs_scale/15; + //float d = -max_scale/8; y[ibl].extra = extra; float id = 1/d; float sumqx = 0, sumq2 = 0; for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { - int ls = nearest_int(0.5f*(id*scales[ib]+15)); - ls = MAX(0, MIN(15, ls)); - y[ibl].scales[ib/2] |= (ls << 4*(ib%2)); - ls = 2*ls - 15; + int ls = nearest_int(id*scales[ib]); + ls = std::max(-8, std::min(7, ls)); + y[ibl].scales[ib/2] |= ((ls + 8) << 4*(ib%2)); float dl = d * ls; if (dl) { const int8_t * block_values = y[ibl].extra & (1 << ib) ? shifted_values : iq2nl_values; @@ -623,7 +677,7 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl } } } - y[ibl].d = GGML_FP32_TO_FP16(1.025f*(sumq2 > 0 ? sumqx/sumq2 : d)); + y[ibl].d = GGML_FP32_TO_FP16(1.030f*(sumq2 > 0 ? sumqx/sumq2 : d)); } } @@ -665,8 +719,8 @@ void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RES int shift = 0; for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - float dl1 = d * (2*(x[i].scales[ib32] & 0xf) - 15); - float dl2 = d * (2*(x[i].scales[ib32] >> 4) - 15); + float dl1 = d * ((x[i].scales[ib32] & 0xf) - 8); + float dl2 = d * ((x[i].scales[ib32] >> 4) - 8); const int8_t * values1 = extra & 1 ? iq2nl_values + 4 : iq2nl_values; const int8_t * values2 = extra & 2 ? iq2nl_values + 4 : iq2nl_values; extra >>= 2; @@ -701,6 +755,347 @@ void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * } +namespace { +void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_sw, int8_t * all_Ls) { + + constexpr int kBlockSize = 32; + constexpr int kMax_i1 = 3*kBlockSize/4; + constexpr int kMin_i3 = kBlockSize/4; + //constexpr int kNtry = 5; + //constexpr float kStep = 1.f; + + ggml_half * dptr = (ggml_half *)vy; + *dptr = GGML_FP32_TO_FP16(0.f); + + block_iq2_ks * y = (block_iq2_ks *)(dptr + 1); + + float weight[kBlockSize]; + float sumx[kBlockSize+1], sumw[kBlockSize+1]; + + std::array<std::pair<float,int>, kBlockSize> pairs; + + float val [4] = {float(iq2nl_values[0]), float(iq2nl_values[1]), float(iq2nl_values[2]), float(iq2nl_values[3])}; + float sval[4] = {float(iq2nl_values[4]), float(iq2nl_values[5]), float(iq2nl_values[6]), float(iq2nl_values[7])}; + + const int8_t * shifted_values = iq2nl_values + 4; + + const int nblock = n_per_row/QK_K; + + for (int ibl = 0; ibl < nblock; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq2_ks)); + + auto scales = all_scales + ibl*(QK_K/kBlockSize); + auto sw = all_sw + ibl*(QK_K/kBlockSize); + + const float * xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = 1.5f*sumx2/QK_K; + + uint16_t extra = 0; + + for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { + const float * xb = xbl + kBlockSize*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize; + for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + sw[ib] = 0; + for (int j = 0; j < kBlockSize; ++j) { + sw[ib] += weight[j]; + pairs[j] = {xb[j], j}; + } + //float amax = 0, max = 0; + //for (int j = 0; j < kBlockSize; ++j) { + // float ax = fabsf(xb[j]); + // if (ax > amax) { + // amax = ax; max = xb[j]; + // } + //} + //if (!amax) { + // scales[ib] = 0; + // continue; + //} + //float d = kNtry > 0 ? -max/iq2nl_values[0] : max/iq2nl_values[0]; + //float id = 1/d; + //float sumqx_p = 0, sumq2_p = 0; + //float sumqx_m = 0, sumq2_m = 0; + //for (int j = 0; j < kBlockSize; ++j) { + // float w = weight[j]; + // float al = id*xb[j]; + // int l = best_index_iq2nl(iq2nl_values, al); + // float q = iq2nl_values[l]; + // sumqx_p += w*q*xb[j]; + // sumq2_p += w*q*q; + // l = best_index_iq2nl(iq2nl_values, -al); + // q = iq2nl_values[l]; + // sumqx_m += w*q*xb[j]; + // sumq2_m += w*q*q; + //} + //d = sumqx_p/sumq2_p; + //float best = d*sumqx_p; + //if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + // d = sumqx_m/sumq2_m; best = d*sumqx_m; + //} + //bool is_shifted = false; + //for (int itry = -kNtry; itry <= kNtry; ++itry) { + // id = (kStep*itry + iq2nl_values[0])/max; + // sumqx_p = sumq2_p = 0; + // sumqx_m = sumq2_m = 0; + // for (int j = 0; j < kBlockSize; ++j) { + // float w = weight[j]; + // float al = id*xb[j]; + // int l = best_index_iq2nl(iq2nl_values, al); + // float q = iq2nl_values[l]; + // sumqx_p += w*q*xb[j]; + // sumq2_p += w*q*q; + // l = best_index_iq2nl(iq2nl_values, -al); + // q = iq2nl_values[l]; + // sumqx_m += w*q*xb[j]; + // sumq2_m += w*q*q; + // } + // if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + // d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; + // } + // if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + // d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; + // } + // id = (kStep*itry + shifted_values[0])/max; + // sumqx_p = sumq2_p = 0; + // sumqx_m = sumq2_m = 0; + // for (int j = 0; j < kBlockSize; ++j) { + // float w = weight[j]; + // float al = id*xb[j]; + // int l = best_index_iq2nl(shifted_values, al); + // float q = shifted_values[l]; + // sumqx_p += w*q*xb[j]; + // sumq2_p += w*q*q; + // l = best_index_iq2nl(shifted_values, -al); + // q = shifted_values[l]; + // sumqx_m += w*q*xb[j]; + // sumq2_m += w*q*q; + // } + // if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + // d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; + // } + // if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + // d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; + // } + //} + std::sort(pairs.begin(), pairs.end()); + sumx[0] = sumw[0] = 0; + for (int j = 0; j < kBlockSize; ++j) { + int jj = pairs[j].second; + sumw[j+1] = sumw[j] + weight[jj]; + sumx[j+1] = sumx[j] + weight[jj]*xb[jj]; + } + float best = 0, d = 0; + bool is_shifted = false; + float sumqx, sumq2; + for (int i1 = 0; i1 < kMax_i1; ++i1) { + for (int i2 = i1; i2 < kBlockSize; ++i2) { + for (int i3 = std::max(i2, kMin_i3); i3 < kBlockSize; ++i3) { + sumqx = (sumx[i1] - sumx[ 0])*val[0] + (sumx[i2] - sumx[i1])*val[1] + + (sumx[i3] - sumx[i2])*val[2] + (sumx[kBlockSize] - sumx[i3])*val[3]; + sumq2 = (sumw[i1] - sumw[ 0])*val[0]*val[0] + (sumw[i2] - sumw[i1])*val[1]*val[1] + + (sumw[i3] - sumw[i2])*val[2]*val[2] + (sumw[kBlockSize] - sumw[i3])*val[3]*val[3]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = false; + } + sumqx = (sumx[i1] - sumx[ 0])*sval[0] + (sumx[i2] - sumx[i1])*sval[1] + + (sumx[i3] - sumx[i2])*sval[2] + (sumx[kBlockSize] - sumx[i3])*sval[3]; + sumq2 = (sumw[i1] - sumw[ 0])*sval[0]*sval[0] + (sumw[i2] - sumw[i1])*sval[1]*sval[1] + + (sumw[i3] - sumw[i2])*sval[2]*sval[2] + (sumw[kBlockSize] - sumw[i3])*sval[3]*sval[3]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = true; + } + sumqx = (sumx[i1] - sumx[ 0])*val[3] + (sumx[i2 ] - sumx[i1])*val[2] + + (sumx[i3] - sumx[i2])*val[1] + (sumx[kBlockSize] - sumx[i3])*val[0]; + sumq2 = (sumw[i1] - sumw[ 0])*val[3]*val[3] + (sumw[i2 ] - sumw[i1])*val[2]*val[2] + + (sumw[i3] - sumw[i2])*val[1]*val[1] + (sumw[kBlockSize] - sumw[i3])*val[0]*val[0]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = false; + } + sumqx = (sumx[i1] - sumx[ 0])*sval[3] + (sumx[i2 ] - sumx[i1])*sval[2] + + (sumx[i3] - sumx[i2])*sval[1] + (sumx[kBlockSize] - sumx[i3])*sval[0]; + sumq2 = (sumw[i1] - sumw[ 0])*sval[3]*sval[3] + (sumw[i2 ] - sumw[i1])*sval[2]*sval[2] + + (sumw[i3] - sumw[i2])*sval[1]*sval[1] + (sumw[kBlockSize] - sumw[i3])*sval[0]*sval[0]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = true; + } + } + } + } + scales[ib] = d; + if (is_shifted) extra |= (1 << ib); + + } + y[ibl].extra = extra; + + } + + float d = make_qx_quants(nblock*(QK_K/kBlockSize), 16, all_scales, all_Ls, all_sw); + + if (!d) return; + + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = 1.5f*sumx2/QK_K; + auto Ls = all_Ls + ibl*(QK_K/kBlockSize); + for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { + int ls = Ls[ib]; + y[ibl].scales[ib/2] |= ((ls & 0xf) << 4*(ib%2)); + y[ibl].extra |= ((ls >> 4) << (8 + ib)); + ls -= 16; + float dl = d * ls; + if (dl) { + const int8_t * block_values = y[ibl].extra & (1 << ib) ? shifted_values : iq2nl_values; + const float * xb = xbl + kBlockSize*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize; + for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float idl = 1/dl; + uint8_t * qs = y[ibl].qs + 32*(ib/4); + for (int j = 0; j < 32; ++j) { + const float al = idl*xb[j]; + int ibest = best_index_iq2nl(block_values, al); + qs[j] |= (ibest << 2*(ib%4)); + float w = weight[j]; + float q = block_values[ibest]*ls; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + } + } + } + *dptr = GGML_FP32_TO_FP16(1.030f*(sumq2 > 0 ? sumqx/sumq2 : d)); +} +} + +void quantize_row_iq2_ks_ref(const float * GGML_RESTRICT x, block_iq2_ks * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq2_ks(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq2_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_iq2_ks * y = (block_iq2_ks *)vy; + quantize_row_iq2_ks_ref(x, y, k); +} + +size_t quantize_iq2_ks(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + constexpr int kBlockSize = 32; + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ2_KS, n_per_row); + int nblock = n_per_row/QK_K; + std::vector<float> all_scales(nblock*(QK_K/kBlockSize)), all_sw(nblock*(QK_K/kBlockSize)); + std::vector<int8_t> all_Ls(nblock*(QK_K/kBlockSize)); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq2_ks_impl(src, (void *)qrow, n_per_row, imatrix, all_scales.data(), all_sw.data(), all_Ls.data()); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void dequantize_row_iq2_ks(const block_iq2_ks * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + const ggml_half * dptr = (const ggml_half *)x; + const float d = GGML_FP16_TO_FP32(*dptr); + x = (const block_iq2_ks *)(dptr + 1); + + for (int i = 0; i < nb; i++) { + + const uint8_t * qs = x[i].qs; + + uint16_t extra = x[i].extra; + + int shift = 0; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + float dl1 = d * (((x[i].scales[ib64] & 0xf) | ((extra >> 4) & 0x10)) - 16); + float dl2 = d * (((x[i].scales[ib64] >> 4) | ((extra >> 5) & 0x10)) - 16); + const int8_t * values1 = extra & 1 ? iq2nl_values + 4 : iq2nl_values; + const int8_t * values2 = extra & 2 ? iq2nl_values + 4 : iq2nl_values; + extra >>= 2; + for (int j = 0; j < 32; ++j) { + y[j+ 0] = dl1 * values1[(qs[j] >> (shift+0)) & 3]; + y[j+32] = dl2 * values2[(qs[j] >> (shift+2)) & 3]; + } + y += 64; + shift += 4; + if (shift == 8) { qs += 32; shift = 0; } + } + + } + +} + +void vec_dot_iq2_ks_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_KS, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + + const ggml_half * dptr = (const ggml_half *)vx; + const float d = GGML_FP16_TO_FP32(*dptr); + const block_iq2_ks * x = (const block_iq2_ks *)(dptr + 1); + const block_q8_K * y = (const block_q8_K *)vy; + + const int nb = n / QK_K; + float sumf = 0; + for (int i = 0; i < nb; i++) { + const uint8_t * qs = x[i].qs; + const int8_t * q8 = y[i].qs; + uint16_t extra = x[i].extra; + int sumi = 0; + for (int ib128 = 0; ib128 < QK_K/128; ++ib128) { + int d1 = (((x[i].scales[2*ib128+0] & 0xf) | ((extra >> 4) & 0x10)) - 16); + int d2 = (((x[i].scales[2*ib128+0] >> 4) | ((extra >> 5) & 0x10)) - 16); + int d3 = (((x[i].scales[2*ib128+1] & 0xf) | ((extra >> 6) & 0x10)) - 16); + int d4 = (((x[i].scales[2*ib128+1] >> 4) | ((extra >> 7) & 0x10)) - 16); + const int8_t * values1 = extra & 1 ? iq2nl_values + 4 : iq2nl_values; + const int8_t * values2 = extra & 2 ? iq2nl_values + 4 : iq2nl_values; + const int8_t * values3 = extra & 4 ? iq2nl_values + 4 : iq2nl_values; + const int8_t * values4 = extra & 8 ? iq2nl_values + 4 : iq2nl_values; + extra >>= 4; + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + for (int j = 0; j < 32; ++j) { + sumi1 += q8[j+ 0] * values1[(qs[j] >> 0) & 3]; + sumi2 += q8[j+32] * values2[(qs[j] >> 2) & 3]; + sumi3 += q8[j+64] * values3[(qs[j] >> 4) & 3]; + sumi4 += q8[j+96] * values4[(qs[j] >> 6) & 3]; + } + sumi += d1*sumi1 + d2*sumi2 + d3*sumi3 + d4*sumi4; + q8 += 128; + qs += 32; + } + sumf += y[i].d * sumi; + } + + *s = d * sumf; + +} + // // ============================================== iq3_k // |