diff options
Diffstat (limited to 'examples/quantize-stats/quantize-stats.cpp')
-rw-r--r-- | examples/quantize-stats/quantize-stats.cpp | 50 |
1 files changed, 34 insertions, 16 deletions
diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 34d05bf2..ff4e9bd4 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -256,6 +256,8 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo float mse0 = 0, mse = 0; auto compute = [&mutex, &counter, &mse0, &mse, values, row_size, nblock, nrows, n_per_row, chunk] () { std::vector<char> Q(row_size); + float diff[4]; + float xv[4]; float lmse0 = 0, lmse = 0; while (true) { std::unique_lock<std::mutex> lock(mutex); @@ -282,25 +284,41 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo for (int j = 0; j < 16; j += 2) { uint16_t v0 = *(const uint16_t *)(qs + j); int non = popcount(v0); - float diff1 = xb[j+ 0] - dl*values[qs[j+0] & 0xf]; - float diff2 = xb[j+16] - dl*values[qs[j+0] >> 4]; - float diff3 = xb[j+ 1] - dl*values[qs[j+1] & 0xf]; - float diff4 = xb[j+17] - dl*values[qs[j+1] >> 4]; - lmse0 += diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4; + xv[0] = xb[j+ 0]; xv[1] = xb[j+16]; xv[2] = xb[j+ 1]; xv[3] = xb[j+17]; + diff[0] = xv[0] - dl*values[qs[j+0] & 0xf]; + diff[1] = xv[1] - dl*values[qs[j+0] >> 4]; + diff[2] = xv[2] - dl*values[qs[j+1] & 0xf]; + diff[3] = xv[3] - dl*values[qs[j+1] >> 4]; + float diff4 = diff[0]*diff[0] + diff[1]*diff[1] + diff[2]*diff[2] + diff[3]*diff[3]; + lmse0 += diff4; if (non%2 == 0) { - lmse += diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4; + lmse += diff4; } else { float best = std::numeric_limits<float>::max(); - for (int k = 0; k < 16; k += 4) { - uint16_t v = v0 ^ (1 << k); - uint8_t v1 = v; - uint8_t v2 = v >> 8; - diff1 = xb[j+ 0] - dl*values[v1 & 0xf]; - diff2 = xb[j+16] - dl*values[v1 >> 4]; - diff3 = xb[j+ 1] - dl*values[v2 & 0xf]; - diff4 = xb[j+17] - dl*values[v2 >> 4]; - float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4; - if (score < best) best = score; + //for (int k = 0; k < 16; k += 4) { + // uint16_t v = v0 ^ (1 << k); + // uint8_t v1 = v; + // uint8_t v2 = v >> 8; + // diff1 = xb[j+ 0] - dl*values[v1 & 0xf]; + // diff2 = xb[j+16] - dl*values[v1 >> 4]; + // diff3 = xb[j+ 1] - dl*values[v2 & 0xf]; + // diff4 = xb[j+17] - dl*values[v2 >> 4]; + // float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4; + // if (score < best) best = score; + //} + for (int k = 0; k < 4; ++k) { + uint16_t v = (v0 >> 4*k) & 0xf; + auto pc = popcount(v); + if (v > 0 && popcount(v-1u) != pc) { + float this_diff = xv[k] - dl*values[v-1u]; + float score = diff4 - diff[k]*diff[k] + this_diff*this_diff; + if (score < best) best = score; + } + if (v < 15 && popcount(v + 1u) != pc) { + float this_diff = xv[k] - dl*values[v+1u]; + float score = diff4 - diff[k]*diff[k] + this_diff*this_diff; + if (score < best) best = score; + } } lmse += best; } |