summary refs log tree commit diff
path: root/iqk-quantize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'iqk-quantize.cpp')
-rw-r--r-- iqk-quantize.cpp 60
1 files changed, 17 insertions, 43 deletions
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp
index 1a672803..40eff93f 100644
--- a/iqk-quantize.cpp
+++ b/iqk-quantize.cpp
@@ -355,8 +355,8 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
}
void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k) {
- assert(k % 64 == 0);
- const int64_t nb = k / 64;
+ //assert(k % 64 == 0);
+ //const int64_t nb = k / 64;
// Check if a row-wise scale works. It almost does, PPL is only ~0.02 higher
//float amax = 0;
@@ -374,50 +374,24 @@ void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k)
// x += 64;
//}
- block_q8_K128 * yp = (block_q8_K128 *)y;
- for (int i = 0; i < nb/2; i++) {
- float max = 0;
- float amax = 0;
- for (int j = 0; j < 128; ++j) {
- float ax = fabsf(x[j]);
- if (ax > amax) {
- amax = ax; max = x[j];
+ float aux[4] = {0.f, 0.f, 0.f, 0.f};
+ for (int j = 0; j < k; j += 16) {
+ for (int i = 0; i < 4; ++i) {
+ for (int l = 0; l < 4; ++l) {
+ float ax = fabsf(x[j+4*i+l]);
+ aux[i] = std::max(aux[i], ax);
}
}
- if (!amax) {
- yp[i].d = 0;
- memset(yp[i].qs, 0, 128);
- x += 128;
- continue;
- }
- const float iscale = -127.f/max;
- for (int j = 0; j < 128; ++j) {
- int v = nearest_int(iscale*x[j]);
- yp[i].qs[j] = MIN(127, v);
- }
- yp[i].d = 1/iscale;
- x += 128;
}
- int i = 2*(nb/2);
- if (i < nb) {
- float max = 0;
- float amax = 0;
- for (int j = 0; j < 64; ++j) {
- float ax = fabsf(x[j]);
- if (ax > amax) {
- amax = ax; max = x[j];
- }
- }
- if (!amax) {
- yp[i/2].d = 0;
- memset(yp[i/2].qs, 0, 64);
- } else {
- const float iscale = -127.f/max;
- for (int j = 0; j < 64; ++j) {
- int v = nearest_int(iscale*x[j]);
- yp[i/2].qs[j] = MIN(127, v);
- }
- yp[i/2].d = 1/iscale;
+ float * dptr = (float *)y;
+ for (int i = 0; i < 4; ++i) {
+ dptr[i] = aux[i]/127;
+ aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f;
+ }
+ auto qs = (int8_t *)(dptr + 4);
+ for (int j = 0; j < k; j += 16) {
+ for (int i = 0; i < 4; ++i) {
+ for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
}
}
}