diff options
Diffstat (limited to 'examples/quantize-stats/quantize-stats.cpp')
-rw-r--r-- | examples/quantize-stats/quantize-stats.cpp | 595 |
1 files changed, 579 insertions, 16 deletions
diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 79905f54..a49ebd92 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -29,6 +29,7 @@ #include <thread> #include <mutex> #include <array> +#include <random> #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -48,6 +49,10 @@ constexpr int popcount(uint32_t x) { return __builtin_popcount(x); } constexpr int popcount(uint64_t x) { return __builtin_popcountll(x); } #endif +#ifdef __AVX2__ +#include <immintrin.h> +#endif + struct quantize_stats_params { std::string model = DEFAULT_MODEL_PATH; bool verbose = false; @@ -253,6 +258,575 @@ static void test_roundtrip_on_layer( } } +static inline int nearest_int(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +static const int8_t scale_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + +static std::vector<float> make_values(int nval, int n_per_val, float scale = 16.f) { + std::vector<float> result(nval*n_per_val); + uint16_t m16 = ggml_fp32_to_fp16(0.922f); + uint32_t m32 = (uint32_t(m16) << 16) | m16; + const uint32_t a = 89226354, b = 64248484; + float * data = result.data(); + for (int i = 0; i < nval; ++i) { + uint32_t x = i + 4096; + for (int k = 0; k < n_per_val; ++k) { + x = a*x + b; + uint32_t s = (x & 0b10001111111111111000111111111111) ^ m32; + float val = ggml_fp16_to_fp32(s & 65535) + ggml_fp16_to_fp32(s >> 16); + int ival = nearest_int(scale*val); + data[k] = ival; + } + data += n_per_val; + } + return result; +} + +#ifdef __AVX2__ +static inline float hsum_float_4(__m128 x) { + x = _mm_add_ps(x, _mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); + return _mm_cvtss_f32(x); +} +static inline float hsum_float_8(__m256 x) { + return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); +} +static __m256 hsum_float_8x8(__m256 * accm) { + for (int i = 0; i < 4; ++i) { + accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), + _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} +#endif + +const int8_t scale_index[241] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 19, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 20, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 21, 21, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 22, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 23, 23, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 24, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 25, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 26, 26, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 27, 27, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 28, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 29, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 30, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15 +}; +inline int best_index_scale(const int8_t * values, float x) { + int ix = (int)x - values[0]; + if (ix < 0 || ix >= 241) return ix < 0 ? 0 : 15; + ix = scale_index[ix]; + return ix < 16 ? ix : x - values[ix-16] < values[ix-15] - x ? ix-16 : ix-15; +} +inline int best_index_iq4nl(const int8_t * values, float x) { return best_index_scale(values, x); } + +static float find_best_scale(int block_size, const float * xb, const float * weight, const int8_t * values, int ntry) { + float amax = 0, max = 0; + for (int j = 0; j < block_size; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + return amax/96.f; //120.f; //127.f; + if (!amax) return 0.f; + float d = ntry > 0 ? -max/values[0] : max/values[0]; + float id = 1/d; + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + d = sumqx_p/sumq2_p; + float best = d*sumqx_p; + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; + } + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; + } + } + return d; +} + +static std::vector<float> cluster_points(const std::vector<float>& points, int ndim, int ncluster, int niter) { + if (points.size() % ndim != 0) { + printf("%s: bad input\n", __func__); return {}; + } + int npoint = points.size() / ndim; + if (npoint < 2*ncluster) { + printf("%s: bad input\n", __func__); return {}; + } + std::vector<std::pair<float, float>> range(ndim, std::make_pair(INFINITY, -INFINITY)); + double Fo = 0; + for (int i = 0; i < npoint; ++i) { + auto v = points.data() + i*ndim; + for (int k = 0; k < ndim; ++k) { + Fo += v[k]*v[k]; + range[k].first = std::min(range[k].first, v[k]); + range[k].second = std::max(range[k].second, v[k]); + } + } + printf("%s (ndim = %d, npoint = %d): Fo = %g\n", __func__, ndim, npoint, Fo/points.size()); + std::mt19937 rndm(1234); + float scale = 1.f/4294967296.f; + std::vector<float> result(ncluster*ndim); + for (int i = 0; i < ncluster; ++i) { + auto v = result.data() + i*ndim; + for (int k = 0; k < ndim; ++k) v[k] = range[k].first + (range[k].second - range[k].first)*scale*rndm(); + } + std::vector<float> sump(ncluster*ndim); + std::vector<int> counts(ncluster); + std::vector<int> which_cluster(npoint, -1); + double Flast = Fo; + for (int iter = 0; iter < niter; ++iter) { + std::memset(sump.data(), 0, sump.size()*sizeof(float)); + std::memset(counts.data(), 0, counts.size()*sizeof(int)); + int nchanged = 0; + double F = 0; + for (int ip = 0; ip < npoint; ++ip) { + auto vp = points.data() + ndim*ip; + float best = INFINITY; int ibest = -1; + for (int ic = 0; ic < ncluster; ++ic) { + auto vc = result.data() + ndim*ic; + float dist2 = 0; + for (int k = 0; k < ndim; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + if (dist2 < best) { + best = dist2; ibest = ic; + } + } + if (ibest < 0) { printf("Oops.\n"); exit(1); } + F += best; + if (which_cluster[ip] != ibest) ++nchanged; + which_cluster[ip] = ibest; + ++counts[ibest]; + auto vc = sump.data() + ndim*ibest; + for (int k = 0; k < ndim; ++k) vc[k] += vp[k]; + } + if (nchanged == 0) break; + for (int ic = 0; ic < ncluster; ++ic) { + float norm = counts[ic] > 0 ? 1.f/counts[ic] : 0.f; + auto vc = sump.data() + ndim*ic; + auto r = result.data() + ndim*ic; + for (int k = 0; k < ndim; ++k) r[k] = vc[k]*norm; + } + printf("%s(iteration %d): F = %g, nchanged = %d\n", __func__, iter+1, F/points.size(), nchanged); + if (iter > 1 && Flast/F - 1 < 1e-6) break; + Flast = F; + } + return result; +} + +static void analyze_x_v2(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_mse_q, float& tot_elements) { + constexpr int kNumVal = 1 << 15; + constexpr int kBlockSize = 32; + constexpr int kGroupSize = 8; + constexpr int kNg = kBlockSize/kGroupSize; + constexpr int kSuperBlockSize = 256; + static_assert(kNumVal%8 == 0); + static std::vector<float> codes, clusters; + static std::vector<std::vector<int>> p_in_cluster; + if (codes.empty()) { + codes = make_values(kNumVal, kGroupSize, 31.75f); + clusters = cluster_points(codes, kGroupSize, kNumVal/512, 200); + if (clusters.empty()) { printf("Oops\n"); exit(1); } + int ncluster = clusters.size()/kGroupSize; + p_in_cluster.resize(ncluster); + std::vector<int> which_cluster(4*kNumVal); + GGML_ASSERT(ncluster%8 == 0); + for (int ip = 0; ip < kNumVal; ++ip) { + auto vp = codes.data() + ip*kGroupSize; + float best[4] = {INFINITY, INFINITY, INFINITY, INFINITY}; + int ibest[4] = {-1, -1, -1, -1}; + for (int ic = 0; ic < ncluster; ++ic) { + auto vc = clusters.data() + ic*kGroupSize; + float dist2 = 0; + for (int k = 0; k < kGroupSize; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + if (dist2 < best[0]) { + best[3] = best[2]; ibest[3] = ibest[2]; + best[2] = best[1]; ibest[2] = ibest[1]; + best[1] = best[0]; ibest[1] = ibest[0]; + best[0] = dist2; ibest[0] = ic; + } + else if (dist2 < best[1]) { + best[3] = best[2]; ibest[3] = ibest[2]; + best[2] = best[1]; ibest[2] = ibest[1]; + best[1] = dist2; ibest[1] = ic; + } + else if (dist2 < best[2]) { + best[3] = best[2]; ibest[3] = ibest[2]; + best[2] = dist2; ibest[2] = ic; + } + else if (dist2 < best[3]) { + best[3] = dist2; ibest[3] = ic; + } + } + GGML_ASSERT(ibest[0] >= 0 && ibest[1] >= 0 && ibest[2] >= 0 && ibest[3] >= 0); + p_in_cluster[ibest[0]].push_back(ip); + p_in_cluster[ibest[1]].push_back(ip); + p_in_cluster[ibest[2]].push_back(ip); + p_in_cluster[ibest[3]].push_back(ip); + std::memcpy(which_cluster.data() + 4*ip, ibest, 4*sizeof(int)); + } + std::vector<std::pair<float, int>> extra; + extra.reserve(kNumVal); + for (int ic = 0; ic < ncluster; ++ic) { + auto& points = p_in_cluster[ic]; + if (!points.empty() && points.size()%8 == 0) continue; + extra.clear(); + auto vc = clusters.data() + ic*kGroupSize; + for (int ip = 0; ip < kNumVal; ++ip) { + if (which_cluster[4*ip] == ic || which_cluster[4*ip+1] == ic || which_cluster[4*ip+2] == ic || which_cluster[4*ip+3] == ic) continue; + auto vp = codes.data() + ip*kGroupSize; + float dist2 = 0; + for (int k = 0; k < kGroupSize; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + extra.push_back(std::make_pair(dist2, ip)); + } + std::sort(extra.begin(), extra.end()); + int nadd = 8*((points.size()+7)/8) - points.size(); + for (int i = 0; i < nadd; ++i) points.push_back(extra[i].second); + GGML_ASSERT(points.size()%8 == 0); + } + auto min = p_in_cluster.front().size(), max = p_in_cluster.front().size(); + int nzero = 0; + for (auto& points : p_in_cluster) { + min = std::min(min, points.size()); + max = std::max(max, points.size()); + if (points.empty()) ++nzero; + } + printf("%s: prepared %d clusters\n", __func__, ncluster); + printf(" min number of points in a cluster: %d\n", int(min)); + printf(" max number of points in a cluster: %d\n", int(max)); + if (nzero > 0) { + printf(" there are %d empty clusters\n", nzero); + for (auto& points : p_in_cluster) { + if (!points.empty()) continue; + points.reserve(kNumVal); + for (int j = 0; j < kNumVal; ++j) points.push_back(j); // i.e., if we end iup picking an empty cluster, we just check all points + } + } + } + int nthread = std::max(1, int(std::thread::hardware_concurrency()/2)); + int chunk = (nrows + 8*nthread - 1)/(8*nthread); + std::mutex mutex; + int counter = 0; + float mse = 0, mse_q = 0; + auto compute = [&mutex, &counter, &mse, &mse_q, values, nrows, n_per_row, chunk] () { + double lmse = 0, lmse_q = 0; + std::vector<float> scales(n_per_row/kBlockSize); + std::vector<int> best_idx(n_per_row/kGroupSize); + std::vector<float> weight(kBlockSize, 1.f); + int ncluster = clusters.size() / kGroupSize; + while (true) { + std::unique_lock<std::mutex> lock(mutex); + int first = counter; counter += chunk; + if (first >= nrows) { + mse += lmse; mse_q += lmse_q; + return; + } + lock.unlock(); + int last = std::min(first + chunk, nrows); +#ifdef __AVX2__ + __m256 sqx[8]; + __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + float sx[8]; + int index[8]; +#endif + for (int row = first; row < last; ++row) { + auto xr = values + row*n_per_row; + float sigma2 = 0; + for (int j = 0; j < n_per_row; ++j) sigma2 += xr[j]*xr[j]; + sigma2 /= n_per_row; + for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) { + auto xb = xr + kBlockSize*ib; + //for (int i = 0; i < kBlockSize; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i]; + float d = find_best_scale(kBlockSize, xb, weight.data(), iq4k_values, 5); + float id = d ? 1/d : 0.f; +#ifdef __AVX2__ + auto vid = _mm256_set1_ps(id); + for (int l = 0; l < kNg; ++l) { + auto xl = xb + 8*l; + auto wl = weight.data() + 8*l; + auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xl)); + auto vw = _mm256_loadu_ps(wl); + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + float best = INFINITY; int jbest = -1; + for (int j = 0; j < ncluster; j += 8) { + auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx); + for (int i = 0; i < 8; ++i) { + auto vq = _mm256_loadu_ps(clusters.data() + kGroupSize*(j+i)); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + } + auto score = hsum_float_8x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + auto& points = p_in_cluster[jbest]; + if (points.empty()) { + printf("Oops: empty cluster %d\n", jbest); + auto vc = clusters.data() + kGroupSize*jbest; + printf("Cluster:\n"); + for (int j = 0; j < kGroupSize; ++j) printf("%d %g %g\n", j, vc[j], xl[j]); + GGML_ASSERT(false); + } + int jbest_cluster = jbest; + vbest = _mm256_set1_ps(INFINITY); + best_index = _mm256_set1_epi32(-1); + best = INFINITY; jbest = -1; + for (int j = 0; j < int(points.size()); j += 8) { + auto idx = _mm256_loadu_si256((const __m256i*)(points.data() + j)); + for (int i = 0; i < 8; ++i) { + auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*points[j+i]); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + } + auto score = hsum_float_8x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + if (jbest < 0) { + printf("Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size())); + GGML_ASSERT(false); + } + GGML_ASSERT(jbest >= 0); + best_idx[ib*kNg + l] = jbest; + } + auto vqx = _mm256_setzero_ps(); + auto vq2 = _mm256_setzero_ps(); + for (int l = 0; l < kNg; ++l) { + auto vx = _mm256_loadu_ps(xb+8*l); + auto vw = _mm256_loadu_ps(weight.data() + 8*l); + auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*best_idx[ib*kNg + l]); + auto vqw = _mm256_mul_ps(vq, vw); + vqx = _mm256_fmadd_ps(vqw, vx, vqx); + vq2 = _mm256_fmadd_ps(vqw, vq, vq2); + } + auto sumqx = hsum_float_8(vqx); + auto sumq2 = hsum_float_8(vq2); + scales[ib] = sumq2 > 0 ? sumqx/sumq2 : 0.f; +#else +#endif + } + float amax_scale = std::abs(scales[0]); + float max_scale = scales[0]; + for (int ib = 1; ib < n_per_row/kBlockSize; ++ib) { + float ax = std::abs(scales[ib]); + if (ax > amax_scale) { + amax_scale = ax; + max_scale = scales[ib]; + } + } + float d = max_scale/scale_values[0]; + float id = d ? 1/d : 0.f; + for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) { + int ls = best_index_scale(scale_values, id*scales[ib]); + float dl = d * scale_values[ls]; + auto xb = xr + kBlockSize*ib; + for (int l = 0; l < kNg; ++l) { + auto q = codes.data() + kGroupSize*best_idx[ib*kNg+l]; + for (int k = 0; k < kGroupSize; ++k) { + float diff1 = xb[kGroupSize*l + k] - scales[ib]*q[k]; + float diff2 = xb[kGroupSize*l + k] - dl*q[k]; + lmse += diff1*diff1; + lmse_q += diff2*diff2; + } + } + } + } + } + }; + std::vector<std::thread> workers(nthread-1); + for (auto& w : workers) w = std::thread(compute); + compute(); + for (auto& w : workers) w.join(); + tot_mse += mse; + tot_mse_q += mse_q; + tot_elements += n_per_row*nrows; + printf("%s: %g %g %g %g\n", name, sqrt(mse/(n_per_row*nrows)), sqrt(tot_mse/tot_elements), + sqrt(mse_q/(n_per_row*nrows)), sqrt(tot_mse_q/tot_elements)); +} + +static void analyze_x(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_mse_q, float& tot_elements) { + constexpr int kNumVal = 1 << 12; + constexpr int kBlockSize = 8; + constexpr int kSuperBlockSize = 256; + static_assert(kNumVal%8 == 0); + auto codes = make_values(kNumVal, kBlockSize); + std::vector<float> sumq2i(kNumVal); + for (int j = 0; j < kNumVal; ++j) { + auto data = codes.data() + kBlockSize*j; + float sum = 0; for (int k = 0; k < kBlockSize; ++k) sum += data[k]*data[k]; + sumq2i[j] = sum > 0 ? 1/sum : 0.f;; + } + int nthread = std::max(1, int(std::thread::hardware_concurrency()/2)); + int chunk = (nrows + 8*nthread - 1)/(8*nthread); + std::mutex mutex; + int counter = 0; + float mse = 0, mse_q = 0; + auto compute = [&mutex, &counter, &mse, &mse_q, &codes, &sumq2i, values, nrows, n_per_row, chunk] () { + float lmse = 0, lmse_q = 0; + std::vector<float> scales(n_per_row/kBlockSize); + std::vector<int> best_idx(n_per_row/kBlockSize); + while (true) { + std::unique_lock<std::mutex> lock(mutex); + int first = counter; counter += chunk; + if (first >= nrows) { + mse += lmse; mse_q += lmse_q; + return; + } + lock.unlock(); + int last = std::min(first + chunk, nrows); +#ifdef __AVX2__ + __m256 vx[kBlockSize/8]; + __m256 sqx[8]; + __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + float sx[8]; + int index[8]; +#endif + for (int row = first; row < last; ++row) { + auto xr = values + row*n_per_row; + for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) { + float best = 0, d = 0; int jbest = -1; + auto xb = xr + kBlockSize*ib; +#ifdef __AVX2__ + for (int l = 0; l < kBlockSize/8; ++l) { + vx[l] = _mm256_loadu_ps(xb+8*l); + } + auto vbest = _mm256_set1_ps(0.f); + auto best_index = _mm256_set1_epi32(-1); + for (int j = 0; j < kNumVal; j += 8) { + auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx); + for (int i = 0; i < 8; ++i) { + sqx[i] = _mm256_setzero_ps(); + for (int l = 0; l < kBlockSize/8; ++l) { + auto qv = _mm256_loadu_ps(codes.data() + kBlockSize*(j+i) + 8*l); + sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]); + } + } + auto sumqx = hsum_float_8x8(sqx); + auto score = _mm256_mul_ps(_mm256_mul_ps(sumqx, sumqx), _mm256_loadu_ps(sumq2i.data() + j)); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_GT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(idx, _mm256_castps_si256(mask)), _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_max_ps(vbest, score); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + best = sx[0]; jbest = index[0]; + for (int j = 1; j < 8; ++j) { + if (sx[j] > best) { best = sx[j]; jbest = index[j]; } + } + auto qv = codes.data() + kBlockSize*jbest; + float sumqx = 0; + for (int k = 0; k < kBlockSize; ++k) sumqx += xb[k]*qv[k]; + d = sumqx*sumq2i[jbest]; +#else + for (int j = 0; j < kNumVal; ++j) { + if (!sumq2i[j]) continue; + auto qv = codes.data() + kBlockSize*j; + float sumqx = 0; + for (int k = 0; k < kBlockSize; ++k) sumqx += qv[k]*xb[k]; + if (sumqx*sumqx*sumq2i[j] > best]) { + d = sumqx*sumq2i[j]; best = d*sumqx; jbest = j; + } + } + auto qv = codes.data() + kBlockSize*jbest; +#endif + scales[ib] = d; + best_idx[ib] = jbest; + for (int k = 0; k < kBlockSize; ++k) { + float diff = xb[k] - d*qv[k]; + lmse += diff*diff; + } + } + float amax_scale = std::abs(scales[0]); + float max_scale = scales[0]; + for (int ib = 1; ib < n_per_row/kBlockSize; ++ib) { + float ax = std::abs(scales[ib]); + if (ax > amax_scale) { + amax_scale = ax; + max_scale = scales[ib]; + } + } + float d = max_scale/scale_values[0]; + float id = d ? 1/d : 0.f; + for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) { + int ls = best_index_scale(scale_values, id*scales[ib]); + float dl = d * scale_values[ls]; + auto xb = xr + kBlockSize*ib; + auto qv = codes.data() + kBlockSize*best_idx[ib]; + for (int k = 0; k < kBlockSize; ++k) { + float diff = xb[k] - dl*qv[k]; + lmse_q += diff*diff; + } + } + } + } + }; + std::vector<std::thread> workers(nthread-1); + for (auto& w : workers) w = std::thread(compute); + compute(); + for (auto& w : workers) w.join(); + tot_mse += mse; + tot_mse_q += mse_q; + tot_elements += n_per_row*nrows; + printf("%s: %g %g %g %g\n", name, sqrt(mse/(n_per_row*nrows)), sqrt(tot_mse/tot_elements), + sqrt(mse_q/(n_per_row*nrows)), sqrt(tot_mse_q/tot_elements)); +} + static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_elements) { int row_size = ggml_row_size(GGML_TYPE_IQ4_KS, n_per_row); int nblock = n_per_row/QK_K; @@ -302,17 +876,6 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo lmse += diff4; } else { float best = std::numeric_limits<float>::max(); - //for (int k = 0; k < 16; k += 4) { - // uint16_t v = v0 ^ (1 << k); - // uint8_t v1 = v; - // uint8_t v2 = v >> 8; - // diff1 = xb[j+ 0] - dl*values[v1 & 0xf]; - // diff2 = xb[j+16] - dl*values[v1 >> 4]; - // diff3 = xb[j+ 1] - dl*values[v2 & 0xf]; - // diff4 = xb[j+17] - dl*values[v2 >> 4]; - // float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4; - // if (score < best) best = score; - //} for (int k = 0; k < 4; ++k) { uint16_t v = (v0 >> 4*k) & 0xf; auto pc = popcount(v); @@ -345,12 +908,12 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo printf("%s: %g %g %g\n", name, sqrt(mse0/(n_per_row*nrows)), sqrt(mse/(n_per_row*nrows)), sqrt(tot_mse/tot_elements)); } -static void analyze_iq4ks(const ggml_tensor * t, float& tot_mse, float& tot_elements) { +static void analyze_iq4ks(const ggml_tensor * t, float& tot_mse, float& tot_mse_q, float& tot_elements) { if (!ggml_is_contiguous(t) || (t->type != GGML_TYPE_F32 && t->type != GGML_TYPE_F16 && t->type != GGML_TYPE_BF16)) { return; } if (t->type == GGML_TYPE_F32) { - analyze_iq4ks(t->name, t->ne[1], t->ne[0], (const float *)t->data, tot_mse, tot_elements); + analyze_x_v2(t->name, t->ne[1], t->ne[0], (const float *)t->data, tot_mse, tot_mse_q, tot_elements); } else { std::vector<float> aux(t->ne[0]*t->ne[1]); if (t->type == GGML_TYPE_F16) { @@ -358,7 +921,7 @@ static void analyze_iq4ks(const ggml_tensor * t, float& tot_mse, float& tot_elem } else { ggml_bf16_to_fp32_row((const ggml_bf16_t *)t->data, aux.data(), aux.size()); } - analyze_iq4ks(t->name, t->ne[1], t->ne[0], aux.data(), tot_mse, tot_elements); + analyze_x_v2(t->name, t->ne[1], t->ne[0], aux.data(), tot_mse, tot_mse_q, tot_elements); } } @@ -542,7 +1105,7 @@ int main(int argc, char ** argv) { std::vector<float> output_scratch; if (analyze) { - float tot_mse = 0, tot_elements = 0; + float tot_mse = 0, tot_mse_q = 0, tot_elements = 0; for (const auto& kv_tensor : tensors) { if (!layer_included(params, kv_tensor.first)) { continue; @@ -551,7 +1114,7 @@ int main(int argc, char ** argv) { // we never quantize those continue; } - analyze_iq4ks(kv_tensor.second, tot_mse, tot_elements); + analyze_iq4ks(kv_tensor.second, tot_mse, tot_mse_q, tot_elements); } return 0; } |