diff options
Diffstat (limited to 'ggml/src/iqk')
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 62 |
1 files changed, 49 insertions, 13 deletions
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index cac1fd49..e2cea7df 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1555,12 +1555,13 @@ inline int best_index_iq3nl(const int8_t * values, float x) { static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) { - const int ntry = 5; + constexpr int ntry = 3; block_iq3_k * y = (block_iq3_k *)vy; float scales[QK_K/16]; float weight[16]; + uint8_t L[16]; const int8_t * shifted_values = iq3nl_values + 8; @@ -1620,7 +1621,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c } bool is_shifted = false; for (int itry = -ntry; itry <= ntry; ++itry) { - id = (itry + iq3nl_values[0])/max; + id = (2*itry + iq3nl_values[0])/max; sumqx_p = sumq2_p = 0; sumqx_m = sumq2_m = 0; for (int j = 0; j < 16; ++j) { @@ -1641,7 +1642,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; } - id = (itry + shifted_values[0])/max; + id = (2*itry + shifted_values[0])/max; sumqx_p = sumq2_p = 0; sumqx_m = sumq2_m = 0; for (int j = 0; j < 16; ++j) { @@ -1663,20 +1664,55 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; } } - if (d) { - const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values; - float sumqx = 0, sumq2 = 0; - id = 1/d; + if (!d) { + scales[ib] = 0; continue; + } + + const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values; + float sumqx = 0, sumq2 = 0; + id = 1/d; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(block_values, al); + L[j] = l; + float q = block_values[l]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0) d = sumqx/sumq2; + + float best_d = d; + for (int iter = 0; iter < 128; ++iter) { + float gmax = 0; + int best_j = -1, dir = 0; for (int j = 0; j < 16; ++j) { float w = weight[j]; - float al = id*xb[j]; - int l = best_index_iq3nl(block_values, al); - float q = block_values[l]; - sumqx += w*q*xb[j]; - sumq2 += w*q*q; + float g = d * w * (xb[j] - d*block_values[L[j]]); + if (g > 0 && L[j] < 7) { + if (g > gmax) { + gmax = g; best_j = j; dir = 1; + } + } + else if (g < 0 && L[j] > 0) { + if (-g > gmax) { + gmax = -g; best_j = j; dir = -1; + } + } } - if (sumq2 > 0) d = sumqx/sumq2; + if (best_j < 0) break; + + float w = weight[best_j]; + sumqx += w*xb[best_j]*(block_values[L[best_j]+dir] - block_values[L[best_j]]); + sumq2 += w*(block_values[L[best_j]+dir]*block_values[L[best_j]+dir] - block_values[L[best_j]]*block_values[L[best_j]]); + L[best_j] += dir; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + best_d = sumqx/sumq2; best = best_d*sumqx; + } + else if (iter > 8) break; + } + scales[ib] = d; if (is_shifted) extra |= (1 << ib); |