summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_quantize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_quantize.cpp')
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp62
1 files changed, 49 insertions, 13 deletions
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index cac1fd49..e2cea7df 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -1555,12 +1555,13 @@ inline int best_index_iq3nl(const int8_t * values, float x) {
static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) {
- const int ntry = 5;
+ constexpr int ntry = 3;
block_iq3_k * y = (block_iq3_k *)vy;
float scales[QK_K/16];
float weight[16];
+ uint8_t L[16];
const int8_t * shifted_values = iq3nl_values + 8;
@@ -1620,7 +1621,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c
}
bool is_shifted = false;
for (int itry = -ntry; itry <= ntry; ++itry) {
- id = (itry + iq3nl_values[0])/max;
+ id = (2*itry + iq3nl_values[0])/max;
sumqx_p = sumq2_p = 0;
sumqx_m = sumq2_m = 0;
for (int j = 0; j < 16; ++j) {
@@ -1641,7 +1642,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false;
}
- id = (itry + shifted_values[0])/max;
+ id = (2*itry + shifted_values[0])/max;
sumqx_p = sumq2_p = 0;
sumqx_m = sumq2_m = 0;
for (int j = 0; j < 16; ++j) {
@@ -1663,20 +1664,55 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true;
}
}
- if (d) {
- const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values;
- float sumqx = 0, sumq2 = 0;
- id = 1/d;
+ if (!d) {
+ scales[ib] = 0; continue;
+ }
+
+ const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values;
+ float sumqx = 0, sumq2 = 0;
+ id = 1/d;
+ for (int j = 0; j < 16; ++j) {
+ float w = weight[j];
+ float al = id*xb[j];
+ int l = best_index_iq3nl(block_values, al);
+ L[j] = l;
+ float q = block_values[l];
+ sumqx += w*q*xb[j];
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0) d = sumqx/sumq2;
+
+ float best_d = d;
+ for (int iter = 0; iter < 128; ++iter) {
+ float gmax = 0;
+ int best_j = -1, dir = 0;
for (int j = 0; j < 16; ++j) {
float w = weight[j];
- float al = id*xb[j];
- int l = best_index_iq3nl(block_values, al);
- float q = block_values[l];
- sumqx += w*q*xb[j];
- sumq2 += w*q*q;
+ float g = d * w * (xb[j] - d*block_values[L[j]]);
+ if (g > 0 && L[j] < 7) {
+ if (g > gmax) {
+ gmax = g; best_j = j; dir = 1;
+ }
+ }
+ else if (g < 0 && L[j] > 0) {
+ if (-g > gmax) {
+ gmax = -g; best_j = j; dir = -1;
+ }
+ }
}
- if (sumq2 > 0) d = sumqx/sumq2;
+ if (best_j < 0) break;
+
+ float w = weight[best_j];
+ sumqx += w*xb[best_j]*(block_values[L[best_j]+dir] - block_values[L[best_j]]);
+ sumq2 += w*(block_values[L[best_j]+dir]*block_values[L[best_j]+dir] - block_values[L[best_j]]*block_values[L[best_j]]);
+ L[best_j] += dir;
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ best_d = sumqx/sumq2; best = best_d*sumqx;
+ }
+ else if (iter > 8) break;
+
}
+
scales[ib] = d;
if (is_shifted) extra |= (1 << ib);