diff options
Diffstat (limited to 'iqk-quantize.cpp')
-rw-r--r-- | iqk-quantize.cpp | 60 |
1 files changed, 17 insertions, 43 deletions
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp index 1a672803..40eff93f 100644 --- a/iqk-quantize.cpp +++ b/iqk-quantize.cpp @@ -355,8 +355,8 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si } void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k) { - assert(k % 64 == 0); - const int64_t nb = k / 64; + //assert(k % 64 == 0); + //const int64_t nb = k / 64; // Check if a row-wise scale works. It almost does, PPL is only ~0.02 higher //float amax = 0; @@ -374,50 +374,24 @@ void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k) // x += 64; //} - block_q8_K128 * yp = (block_q8_K128 *)y; - for (int i = 0; i < nb/2; i++) { - float max = 0; - float amax = 0; - for (int j = 0; j < 128; ++j) { - float ax = fabsf(x[j]); - if (ax > amax) { - amax = ax; max = x[j]; + float aux[4] = {0.f, 0.f, 0.f, 0.f}; + for (int j = 0; j < k; j += 16) { + for (int i = 0; i < 4; ++i) { + for (int l = 0; l < 4; ++l) { + float ax = fabsf(x[j+4*i+l]); + aux[i] = std::max(aux[i], ax); } } - if (!amax) { - yp[i].d = 0; - memset(yp[i].qs, 0, 128); - x += 128; - continue; - } - const float iscale = -127.f/max; - for (int j = 0; j < 128; ++j) { - int v = nearest_int(iscale*x[j]); - yp[i].qs[j] = MIN(127, v); - } - yp[i].d = 1/iscale; - x += 128; } - int i = 2*(nb/2); - if (i < nb) { - float max = 0; - float amax = 0; - for (int j = 0; j < 64; ++j) { - float ax = fabsf(x[j]); - if (ax > amax) { - amax = ax; max = x[j]; - } - } - if (!amax) { - yp[i/2].d = 0; - memset(yp[i/2].qs, 0, 64); - } else { - const float iscale = -127.f/max; - for (int j = 0; j < 64; ++j) { - int v = nearest_int(iscale*x[j]); - yp[i/2].qs[j] = MIN(127, v); - } - yp[i/2].d = 1/iscale; + float * dptr = (float *)y; + for (int i = 0; i < 4; ++i) { + dptr[i] = aux[i]/127; + aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f; + } + auto qs = (int8_t *)(dptr + 4); + for (int j = 0; j < k; j += 16) { + for (int i = 0; i < 4; ++i) { + for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]); } } } |