summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_quantize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_quantize.cpp')
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp1386
1 files changed, 1386 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 93aa2180..c1f7a8e4 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -22,6 +22,8 @@
#include <algorithm>
#include <cstring>
#include <mutex>
+#include <random>
+#include <memory>
#include <thread>
#include <atomic>
#include <unordered_map>
@@ -7408,3 +7410,1387 @@ void dequantize_row_ms_i2s(const void * vx, float * y, int64_t k) {
}
}
+namespace {
+#ifdef __AVX2__
+__m128 hsum_float_4x4(__m128 * accm) {
+ accm[0] = _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[2]), _mm_unpackhi_ps(accm[0], accm[2]));
+ accm[1] = _mm_add_ps(_mm_unpacklo_ps(accm[1], accm[3]), _mm_unpackhi_ps(accm[1], accm[3]));
+ return _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[1]), _mm_unpackhi_ps(accm[0], accm[1]));
+}
+__m256 hsum_float_8x8(__m256 * accm) {
+ for (int i = 0; i < 4; ++i) {
+ accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
+ _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
+ }
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+ return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+}
+__m256 hsum_float_4x8(__m256 * accm) {
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+ return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+}
+#endif
+template <int block_size, int group_size, int num_bits, bool is_abs = false>
+class QuantizerIQKT {
+ static_assert(group_size == 8 || group_size == 4);
+ static_assert(block_size >= 8 && block_size%8 == 0);
+public:
+ constexpr static int kSuperBlockSize = QK_K;
+ constexpr static int kBlockSize = block_size;
+ constexpr static int kGroupSize = group_size;
+ constexpr static int kNg = kBlockSize/kGroupSize;
+ constexpr static int kNblock = kSuperBlockSize/kBlockSize;
+ constexpr static int kNumVal = 1 << num_bits; // i.e, 16 bits per group of 8
+ constexpr static float kScale = 31.75f;
+ constexpr static bool kVerbose = false;
+
+ QuantizerIQKT(int num_clusters, int num_neighbours, int offset = 4096);
+ const float * values() const { return m_values.data(); }
+
+ inline void find_best_match(float d, const float * xb, const float * weight, int * best_idx) const;
+ inline std::pair<float, float> find_best_scale(const float * xb, const float * weight, const int * best_idx) const;
+ inline float find_best_inverse_scale(const float * xb, const float * weight, const int * best_idx) const;
+
+ static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) {
+ constexpr uint32_t ka = 89226354;
+ constexpr uint32_t kb = 64248484;
+ constexpr uint32_t kmask = 0x8fff8fff;
+ constexpr uint32_t km32 = 0x3b603b60;
+ uint32_t x = i + offset;
+ for (int k = 0; k < kGroupSize; ++k) {
+ x = ka*x + kb;
+ uint32_t s = (x & kmask) ^ km32;
+ float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16);
+ if constexpr (is_abs) result[k] = scale*std::abs(val);
+ else result[k] = scale*val;
+ }
+ }
+
+ static inline int bin4(float x) {
+ if constexpr (is_abs) {
+ return x < 16.f ? 0 : x < 32.f ? 1 : x < 64.f ? 2 : 3;
+ } else {
+ return x < -24.f ? 0 : x < 0.0f ? 1 : x < 24.f ? 2 : 3;
+ }
+ }
+ static inline int bin5(float x) {
+ if constexpr (is_abs) {
+ return x < 11.2f ? 0 : x < 24.f ? 1 : x < 39.f ? 2 : x < 58.f ? 3 : 4;
+ } else {
+ return x < -48.f ? 0 : x < -16.f ? 1 : x < 16.f ? 2 : x < 48.f ? 3 : 4;
+ }
+ }
+ inline int bin3(int idim, float x) const { return x < m_mid[2*idim+0] ? 0 : x < m_mid[2*idim+1] ? 1 : 2; }
+
+ static inline void set_weights(float sigma2_scale, int nblock, const float * x, const float * imatrix, float * row_weights) {
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ const float * xbl = x + ibl*kSuperBlockSize;
+ float * wbl = row_weights + ibl*kSuperBlockSize;
+
+ float sumx2 = 0;
+ for (int j = 0; j < kSuperBlockSize; ++j) sumx2 += xbl[j]*xbl[j];
+ const float sigma2 = sigma2_scale*sumx2/kSuperBlockSize;
+
+ if (imatrix) {
+ const float * qw = imatrix + ibl*kSuperBlockSize;
+ for (int j = 0; j < kSuperBlockSize; ++j) wbl[j] = qw[j] * sqrtf(sigma2 + xbl[j]*xbl[j]);
+ } else {
+ for (int j = 0; j < kSuperBlockSize; ++j) wbl[j] = 0.25f*sigma2 + xbl[j]*xbl[j];
+ }
+ }
+ }
+private:
+ static std::vector<float> cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid);
+ static std::vector<std::vector<int>> finalize_clusters(int num_neighbours, const std::vector<float>& points, const std::vector<float>& clusters,
+ std::vector<std::vector<float>>& c_values);
+ std::vector<float> m_values;
+ std::vector<float> m_clusters;
+ std::vector<std::vector<int>> m_in_cluster;
+ std::vector<std::vector<float>> m_c_values;
+ float m_mid[4*kGroupSize];
+};
+
+template <int block_size, int group_size, int num_bits, bool is_abs>
+QuantizerIQKT<block_size, group_size, num_bits, is_abs>::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) {
+ m_values.resize(kNumVal*kGroupSize);
+ float * data = m_values.data();
+ for (int i = 0; i < kNumVal; ++i) {
+ set_values(i, data, kScale, offset);
+ data += kGroupSize;
+ }
+ // Make 128 clusters.
+ // Note: we get a slightly better result by using 64 clusters
+ // at the expense of almost doubling the quantization time.
+ m_clusters = cluster_points(m_values, num_clusters, 200, m_mid);
+ GGML_ASSERT(!m_clusters.empty());
+ m_in_cluster = finalize_clusters(num_neighbours, m_values, m_clusters, m_c_values);
+}
+
+template <int block_size, int group_size, int num_bits, bool is_abs>
+std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_scale(
+ const float * xb, const float * weight, const int * best_idx) const {
+ float sumqx = 0, sumq2 = 0;
+#ifdef __AVX2__
+ auto vqx = _mm256_setzero_ps();
+ auto vq2 = _mm256_setzero_ps();
+ for (int l = 0; l < kBlockSize; l += 8) {
+ auto vx = _mm256_loadu_ps(xb+l);
+ auto vw = _mm256_loadu_ps(weight+l);
+ auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) :
+ _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]),
+ _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0]));
+ auto vqw = _mm256_mul_ps(vq, vw);
+ vqx = _mm256_fmadd_ps(vqw, vx, vqx);
+ vq2 = _mm256_fmadd_ps(vqw, vq, vq2);
+ }
+ sumqx = hsum_float_8(vqx);
+ sumq2 = hsum_float_8(vq2);
+#else
+ for (int l = 0; l < kNg; ++l) {
+ auto xl = xb + kGroupSize*l;
+ auto wl = weight + kGroupSize*l;
+ auto ql = m_values.data() + kGroupSize*best_idx[l];
+ for (int k = 0; k < kGroupSize; ++k) {
+ sumqx += wl[k]*ql[k]*xl[k];
+ sumq2 += wl[k]*ql[k]*ql[k];
+ }
+ }
+#endif
+ return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f);
+}
+
+template <int block_size, int group_size, int num_bits, bool is_abs>
+float QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_inverse_scale(
+ const float * xb, const float * weight, const int * best_idx) const {
+ float sumqx = 0, sumx2 = 0;
+#ifdef __AVX2__
+ auto vqx = _mm256_setzero_ps();
+ auto vx2 = _mm256_setzero_ps();
+ for (int l = 0; l < kBlockSize; l += 8) {
+ auto vx = _mm256_loadu_ps(xb+l);
+ auto vw = _mm256_loadu_ps(weight+l);
+ auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) :
+ _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]),
+ _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0]));
+ auto vxw = _mm256_mul_ps(vx, vw);
+ vx2 = _mm256_fmadd_ps(vxw, vx, vx2);
+ vqx = _mm256_fmadd_ps(vxw, vq, vqx);
+ }
+ sumqx = hsum_float_8(vqx);
+ sumx2 = hsum_float_8(vx2);
+#else
+ for (int l = 0; l < kNg; ++l) {
+ auto xl = xb + kGroupSize*l;
+ auto wl = weight + kGroupSize*l;
+ auto ql = m_values.data() + kGroupSize*best_idx[l];
+ for (int k = 0; k < kGroupSize; ++k) {
+ sumqx += wl[k]*ql[k]*xl[k];
+ sumx2 += wl[k]*xl[k]*xl[k];
+ }
+ }
+#endif
+ return sumx2 > 0 ? sumqx/sumx2 : 0.f;
+}
+
+template <int block_size, int group_size, int num_bits, bool is_abs>
+void QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const {
+ if (!d) {
+ std::memset(best_idx, 0, kNg*sizeof(int));
+ return;
+ }
+ int ncluster = m_clusters.size()/kGroupSize;
+ float id = 1/d;
+#ifdef __AVX2__
+ if constexpr (kGroupSize == 8) {
+ __m256 sqx[8];
+ const __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+ float sx[8];
+ int index[8];
+ auto vid = _mm256_set1_ps(id);
+ auto add8 = _mm256_set1_epi32(8);
+ for (int l = 0; l < kNg; ++l) {
+ auto xl = xb + 8*l;
+ auto wl = weight + 8*l;
+ auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xl));
+ auto vw = _mm256_loadu_ps(wl);
+ int jbest = -1;
+ if (kGroupSize == 8 && (ncluster == 256 || ncluster == 6561)) {
+ _mm256_store_ps(sx, vx);
+ uint16_t u = 0;
+ if (ncluster == 256) {
+ for (int j = 0; j < 8; ++j) if (sx[j] > m_mid[j]) u |= (1 << j);
+ } else {
+ int s = 1;
+ for (int j = 0; j < 8; ++j) { u += s*bin3(j, sx[j]); s *= 3; }
+ }
+ jbest = u;
+ } else {
+ auto vbest = _mm256_set1_ps(INFINITY);
+ auto best_index = _mm256_set1_epi32(-1);
+ float best = INFINITY;
+ auto idx = add_idx;
+ for (int j = 0; j < ncluster; j += 8) {
+ for (int i = 0; i < 8; ++i) {
+ auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+i));
+ auto vdiff = _mm256_sub_ps(vq, vx);
+ sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
+ }
+ auto score = hsum_float_8x8(sqx);
+ auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
+ best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+ _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+ vbest = _mm256_min_ps(vbest, score);
+ idx = _mm256_add_epi32(idx, add8);
+ }
+ _mm256_store_ps(sx, vbest);
+ _mm256_store_si256((__m256i *)index, best_index);
+ for (int i = 0; i < 8; ++i) {
+ if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
+ }
+ }
+ auto& points = m_in_cluster[jbest];
+ auto& values = points.empty() ? m_values : m_c_values[jbest];
+ int npoint = values.size()/kGroupSize;
+ GGML_ASSERT(npoint > 0 && npoint%8 == 0);
+ int jbest_cluster = jbest;
+ auto vbest = _mm256_set1_ps(INFINITY);
+ auto best_index = _mm256_set1_epi32(-1);
+ auto best = INFINITY; jbest = -1;
+ auto idx = add_idx;
+ for (int j = 0; j < npoint; j += 8) {
+ for (int i = 0; i < 8; ++i) {
+ auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+i));
+ auto vdiff = _mm256_sub_ps(vq, vx);
+ sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
+ }
+ auto score = hsum_float_8x8(sqx);
+ auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
+ best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+ _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+ vbest = _mm256_min_ps(vbest, score);
+ idx = _mm256_add_epi32(idx, add8);
+ }
+ _mm256_store_ps(sx, vbest);
+ _mm256_store_si256((__m256i *)index, best_index);
+ for (int i = 0; i < 8; ++i) {
+ if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
+ }
+ if (jbest < 0) {
+ fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
+ GGML_ASSERT(false);
+ }
+ best_idx[l] = points.empty() ? jbest : points[jbest];
+ }
+ } else {
+ __m256 sqx[4];
+ const __m256i add_idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+ const __m256 sign_bit = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff));
+ float sx[8];
+ int index[8];
+ auto vid_p = _mm256_set1_ps(id);
+ auto add8 = _mm256_set1_epi32(8);
+ for (int l = 0; l < kNg; ++l) {
+ auto xl = xb + 4*l;
+ auto wl = weight + 4*l;
+ auto vx4 = _mm_loadu_ps(xl);
+ auto vx = _mm256_mul_ps(vid_p, _mm256_set_m128(vx4, vx4));
+ auto vw4 = _mm_loadu_ps(wl);
+ auto vw = _mm256_set_m128(vw4, vw4);
+ int jbest = -1;
+ if (ncluster == 256 || ncluster == 625) {
+ _mm256_storeu_ps(sx, vx);
+ uint16_t u = 0;
+ if (ncluster == 256) {
+ for (int k = 0; k < 4; ++k) u |= (bin4(sx[k]) << 2*k);
+ } else {
+ int l = 1;
+ for (int k = 0; k < 4; ++k) { u += bin5(sx[k])*l; l *= 5; }
+ }
+ jbest = u;
+ } else {
+ auto vbest = _mm256_set1_ps(INFINITY);
+ auto best_index = _mm256_set1_epi32(-1);
+ float best = INFINITY;
+ auto idx = add_idx;
+ for (int j = 0; j < ncluster; j += 8) {
+ for (int i = 0; i < 4; ++i) {
+ auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+2*i));
+ auto vdiff = _mm256_sub_ps(vq, vx);
+ vdiff = _mm256_and_ps(sign_bit, vdiff);
+ sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff)));
+ }
+ auto score = hsum_float_4x8(sqx);
+ auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
+ best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+ _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+ vbest = _mm256_min_ps(vbest, score);
+ idx = _mm256_add_epi32(idx, add8);
+ }
+ _mm256_store_ps(sx, vbest);
+ _mm256_store_si256((__m256i *)index, best_index);
+ for (int i = 0; i < 8; ++i) {
+ if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
+ }
+ }
+ auto& points = m_in_cluster[jbest];
+ auto& values = m_c_values[jbest];
+ GGML_ASSERT(!points.empty() && points.size()%8 == 0);
+ int jbest_cluster = jbest;
+ auto vbest = _mm256_set1_ps(INFINITY);
+ auto best_index = _mm256_set1_epi32(-1);
+ float best = INFINITY; jbest = -1;
+ auto idx = add_idx;
+ for (int j = 0; j < int(points.size()); j += 8) {
+ for (int i = 0; i < 4; ++i) {
+ auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+2*i));
+ auto vdiff = _mm256_sub_ps(vq, vx);
+ sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
+ }
+ auto score = hsum_float_4x8(sqx);
+ auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
+ best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+ _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+ vbest = _mm256_min_ps(vbest, score);
+ idx = _mm256_add_epi32(idx, add8);
+ }
+ _mm256_store_ps(sx, vbest);
+ _mm256_store_si256((__m256i *)index, best_index);
+ for (int i = 0; i < 8; ++i) {
+ if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
+ }
+ if (jbest < 0) {
+ fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
+ GGML_ASSERT(false);
+ }
+ best_idx[l] = points[jbest];
+ }
+ }
+#else
+ // TODO
+ std::memset(best_idx, 0, kNg*sizeof(int));
+#endif
+}
+
+template <int block_size, int group_size, int num_bits, bool is_abs>
+std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::finalize_clusters(int num_neighbours,
+ const std::vector<float>& values, const std::vector<float>& clusters, std::vector<std::vector<float>>& c_values) {
+ int ncluster = clusters.size()/kGroupSize;
+ std::vector<std::vector<int>> p_in_cluster(ncluster);
+ std::vector<int> which_cluster(num_neighbours*kNumVal);
+ std::vector<int> ibest(num_neighbours);
+ std::vector<float> best(num_neighbours);
+ for (int ip = 0; ip < kNumVal; ++ip) {
+ auto vp = values.data() + ip*kGroupSize;
+ for (int j = 0; j < num_neighbours; ++j) {
+ best[j] = INFINITY; ibest[j] = -1;
+ }
+ for (int ic = 0; ic < ncluster; ++ic) {
+ auto vc = clusters.data() + ic*kGroupSize;
+ float dist2 = 0;
+ for (int k = 0; k < kGroupSize; ++k) {
+ float d = vp[k] - vc[k]; dist2 += d*d;
+ }
+ for (int j = 0; j < num_neighbours; ++j) {
+ if (dist2 < best[j]) {
+ for (int k = num_neighbours-1; k > j; --k) {
+ best[k] = best[k-1]; ibest[k] = ibest[k-1];
+ }
+ best[j] = dist2; ibest[j] = ic;
+ break;
+ }
+ }
+ }
+ for (int j = 0; j < num_neighbours; ++j) {
+ if (ibest[j] < 0) {
+ printf("Oops: ibest[%d] = %d\n", j, ibest[j]);
+ }
+ GGML_ASSERT(ibest[j] >= 0);
+ p_in_cluster[ibest[j]].push_back(ip);
+ }
+ std::memcpy(which_cluster.data() + num_neighbours*ip, ibest.data(), num_neighbours*sizeof(int));
+ }
+ std::vector<std::pair<float, int>> extra;
+ extra.reserve(kNumVal);
+ for (int ic = 0; ic < ncluster; ++ic) {
+ auto& points = p_in_cluster[ic];
+ if (!points.empty() && points.size()%8 == 0) continue;
+ extra.clear();
+ auto vc = clusters.data() + ic*kGroupSize;
+ for (int ip = 0; ip < kNumVal; ++ip) {
+ bool can_add = true;
+ for (int j = 0; j < num_neighbours; ++j) {
+ if (which_cluster[num_neighbours*ip+j] == ic) { can_add = false; break; }
+ }
+ if (!can_add) continue;
+ auto vp = values.data() + ip*kGroupSize;
+ float dist2 = 0;
+ for (int k = 0; k < kGroupSize; ++k) {
+ float d = vp[k] - vc[k]; dist2 += d*d;
+ }
+ extra.push_back(std::make_pair(dist2, ip));
+ }
+ std::sort(extra.begin(), extra.end());
+ int nadd = 8*((points.size()+7)/8) - points.size();
+ for (int i = 0; i < nadd; ++i) points.push_back(extra[i].second);
+ GGML_ASSERT(points.size()%8 == 0);
+ }
+ auto min = p_in_cluster.front().size(), max = p_in_cluster.front().size();
+ for (auto& points : p_in_cluster) {
+ min = std::min(min, points.size());
+ max = std::max(max, points.size());
+ }
+ c_values.resize(p_in_cluster.size());
+ for (int i = 0; i < int(p_in_cluster.size()); ++i) {
+ auto& points = p_in_cluster[i];
+ c_values[i].resize(points.size()*kGroupSize);
+ auto ptr = c_values[i].data();
+ for (auto j : points) {
+ std::memcpy(ptr, values.data() + j*kGroupSize, kGroupSize*sizeof(float));
+ ptr += kGroupSize;
+ }
+ }
+
+ if (kVerbose) {
+ printf("%s: prepared %d clusters\n", __func__, ncluster);
+ printf(" min number of points in a cluster: %d\n", int(min));
+ printf(" max number of points in a cluster: %d\n", int(max));
+ }
+ return p_in_cluster;
+}
+
+template <int block_size, int group_size, int num_bits, bool is_abs>
+std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid) {
+ constexpr int ndim = kGroupSize;
+ GGML_ASSERT(points.size() % ndim == 0);
+ int npoint = points.size() / ndim;
+ GGML_ASSERT(npoint >= 2*ncluster);
+ std::vector<std::pair<float, float>> range(ndim, std::make_pair(INFINITY, -INFINITY));
+ double Fo = 0;
+ for (int i = 0; i < npoint; ++i) {
+ auto v = points.data() + i*ndim;
+ for (int k = 0; k < ndim; ++k) {
+ Fo += v[k]*v[k];
+ range[k].first = std::min(range[k].first, v[k]);
+ range[k].second = std::max(range[k].second, v[k]);
+ }
+ }
+ if (kVerbose) printf("%s (ndim = %d, npoint = %d): Fo = %g\n", __func__, ndim, npoint, Fo/points.size());
+ if constexpr (is_abs) {
+ std::vector<int> P(npoint);
+ for (int idim = 0; idim < ndim; ++idim) {
+ for (int ip = 0; ip < npoint; ++ip) P[ip] = points[ip*ndim+idim];
+ std::sort(P.begin(), P.end());
+ if (ndim == 8 && ncluster == 6561) {
+ mid[2*idim + 0] = P[npoint/3];
+ mid[2*idim + 1] = P[2*npoint/3];
+ } else {
+ mid[idim] = npoint%2 == 0 ? 0.5f*(P[npoint/2] + P[npoint/2-1]) : P[npoint/2];
+ if (kVerbose) printf("%s: mid[%d] = %g\n", __func__, idim, mid[idim]);
+ }
+ }
+ } else {
+ for (int k = 0; k < ndim; ++k) mid[k] = 0.5f*(range[k].first + range[k].second);
+ }
+ std::vector<float> sump(ncluster*ndim);
+ std::vector<int> counts(ncluster);
+ std::vector<float> result(ncluster*ndim);
+ if (ndim == 8 && (ncluster == 256 || ncluster == 6561)) {
+ std::memset(sump.data(), 0, sump.size()*sizeof(float));
+ std::memset(counts.data(), 0, counts.size()*sizeof(int));
+ for (int ip = 0; ip < npoint; ++ip) {
+ auto vp = points.data() + ndim*ip;
+ uint16_t u = 0;
+ if (ncluster == 256) {
+ for (int k = 0; k < ndim; ++k) if (vp[k] > mid[k]) u |= (1 << k);
+ } else {
+ int s = 1;
+ for (int k = 0; k < ndim; ++k) {
+ int bin = vp[k] < mid[2*k+0] ? 0 : vp[k] < mid[2*k+1] ? 1 : 2;
+ u += s*bin; s *= 3;
+ }
+ }
+ ++counts[u];
+ for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k];
+ }
+ for (int ic = 0; ic < ncluster; ++ic) {
+ if (!counts[ic]) {
+ printf("%s: Oops. Cluster %d has no points\n", __func__, ic);
+ GGML_ABORT("fatal error");
+ }
+ for (int k = 0; k < ndim; ++k) result[ic*ndim + k] = sump[ic*ndim + k]/counts[ic];
+ }
+ return result;
+ }
+ else if (ndim == 4 && (ncluster == 256 || ncluster == 625)) {
+ std::memset(sump.data(), 0, sump.size()*sizeof(float));
+ std::memset(counts.data(), 0, counts.size()*sizeof(int));
+ for (int ip = 0; ip < npoint; ++ip) {
+ auto vp = points.data() + ndim*ip;
+ uint16_t u = 0;
+ if (ncluster == 256) {
+ for (int k = 0; k < ndim; ++k) u |= (bin4(vp[k]) << 2*k);
+ } else {
+ int s = 1;
+ for (int k = 0; k < ndim; ++k) { u += s*bin5(vp[k]); s *= 5; }
+ }
+ if (u >= int(counts.size())) {
+ printf("Oops: u = %u, vp = %g, %g, %g, %g\n", u, vp[0], vp[1], vp[2], vp[3]);
+ u = 0;
+ if (ncluster == 256) {
+ for (int k = 0; k < ndim; ++k) {
+ auto bin = bin4(vp[k]); u |= (bin << 2*k);
+ printf(" bin[%d] = %d, u = %u", k, bin, u);
+ }
+ } else {
+ for (int k = 0; k < ndim; ++k) printf(" bin[%d] = %d", k, bin5(vp[k]));
+ }
+ printf("\n");
+ GGML_ABORT("fatal error");
+ }
+ ++counts[u];
+ for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k];
+ }
+ int nzero = 0;
+ for (int ic = 0; ic < ncluster; ++ic) {
+ if (!counts[ic]) {
+ ++nzero;
+ printf("%s: Oops. Cluster %d has no points: ", __func__, ic);
+ for (int k = 0; k < ndim; ++k) {
+ int l = (ic >> 2*k) & 3;
+ printf(" %d", l);
+ }
+ printf("\n");
+ } else {
+ for (int k = 0; k < ndim; ++k) result[ic*ndim + k] = sump[ic*ndim + k]/counts[ic];
+ }
+ }
+ if (nzero > 0) printf("%s: %d out of %d clusters dir not have any points\n", __func__, nzero, ncluster);
+ return result;
+ }
+ std::mt19937 rndm(1234);
+ float scale = 1.f/4294967296.f;
+ for (int i = 0; i < ncluster; ++i) {
+ auto v = result.data() + i*ndim;
+ for (int k = 0; k < ndim; ++k) v[k] = range[k].first + (range[k].second - range[k].first)*scale*rndm();
+ }
+ std::vector<int> which_cluster(npoint, -1);
+ double Flast = Fo;
+ for (int iter = 0; iter < niter; ++iter) {
+ std::memset(sump.data(), 0, sump.size()*sizeof(float));
+ std::memset(counts.data(), 0, counts.size()*sizeof(int));
+ int nchanged = 0;
+ double F = 0;
+ for (int ip = 0; ip < npoint; ++ip) {
+ auto vp = points.data() + ndim*ip;
+ float best = INFINITY; int ibest = -1;
+ for (int ic = 0; ic < ncluster; ++ic) {
+ auto vc = result.data() + ndim*ic;
+ float dist2 = 0;
+ for (int k = 0; k < ndim; ++k) {
+ float d = vp[k] - vc[k]; dist2 += d*d;
+ }
+ if (dist2 < best) {
+ best = dist2; ibest = ic;
+ }
+ }
+ if (ibest < 0) {
+ printf("Oops(iteration %d) - failed to find cluster for point", iter);
+ for (int k = 0; k < ndim; ++k) printf(" %g", vp[k]);
+ printf("\nHave %d clusters\n", ncluster);
+ }
+ GGML_ASSERT(ibest >= 0);
+ F += best;
+ if (which_cluster[ip] != ibest) ++nchanged;
+ which_cluster[ip] = ibest;
+ ++counts[ibest];
+ auto vc = sump.data() + ndim*ibest;
+ for (int k = 0; k < ndim; ++k) vc[k] += vp[k];
+ }
+ if (nchanged == 0) break;
+ for (int ic = 0; ic < ncluster; ++ic) {
+ float norm = counts[ic] > 0 ? 1.f/counts[ic] : 0.f;
+ auto vc = sump.data() + ndim*ic;
+ auto r = result.data() + ndim*ic;
+ for (int k = 0; k < ndim; ++k) r[k] = vc[k]*norm;
+ }
+ if (kVerbose) printf("%s(iteration %d): F = %g, nchanged = %d\n", __func__, iter+1, F/points.size(), nchanged);
+ if (iter > 1 && Flast/F - 1 < 1e-6) break;
+ Flast = F;
+ }
+ int nzero = 0;
+ for (int ic = 0; ic < ncluster; ++ic) {
+ if (!counts[ic]) ++nzero;
+ }
+ if (nzero > 0) printf("%s: there are %d empty clusters\n", __func__, nzero);
+ return result;
+}
+
+// ========================================== iq2_kt ====================================================
+
+using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16>;
+
+const QuantizerIQ2KT& iq2kt_quantizer() {
+ static std::mutex mutex;
+ static std::unique_ptr<QuantizerIQ2KT> quantizer;
+ std::lock_guard<std::mutex> lock(mutex);
+ if (!quantizer) quantizer = std::make_unique<QuantizerIQ2KT>(256, 8);
+ return *quantizer;
+}
+
+void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights,
+ float * qtmp) {
+
+ constexpr float kSigmaScale = 2.0f;
+ using Q = QuantizerIQ2KT;
+
+ static_assert(Q::kNumVal%8 == 0);
+
+ float * dptr = (float *)vy;
+
+ block_iq2_kt * y = (block_iq2_kt *)(dptr + 1);
+
+ int best_idx[2*Q::kNg];
+
+ auto& quantizer = iq2kt_quantizer();
+
+ int nblock = n_per_row / Q::kSuperBlockSize;
+
+ Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights);
+
+ float amax_scale = 0, max_scale = 0;
+
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ memset(&y[ibl], 0, sizeof(block_iq2_kt));
+
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * xb = xbl + Q::kBlockSize*ib;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ float amax = 0;
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ float ax = std::abs(xb[j]);
+ amax = std::max(amax, ax);
+ }
+ quantizer.find_best_match( amax/96.f, xb, weight, best_idx);
+ auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx);
+ quantizer.find_best_match(-amax/96.f, xb, weight, best_idx + Q::kNg);
+ auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg);
+
+ auto idx = best_idx;
+ if (score_p > score_m) scales[ib] = dp;
+ else {
+ scales[ib] = dm; idx += Q::kNg;
+ }
+ auto qt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ for (int ig = 0; ig < Q::kNg; ++ig) {
+ auto q = quantizer.values() + idx[ig]*Q::kGroupSize;
+ for (int j = 0; j < Q::kGroupSize; ++j) qt[j] = q[j];
+ qt += Q::kGroupSize;
+ }
+
+ float abs_scale = std::abs(scales[ib]);
+ if (abs_scale > amax_scale) {
+ amax_scale = abs_scale;
+ max_scale = scales[ib];
+ }
+ }
+
+ }
+
+ if (!max_scale) {
+ *dptr = 0;
+ return;
+ }
+
+ float d = max_scale/iq4k_values[0];
+ float best = 0;
+ for (int itry = -9; itry <= 9; ++itry) {
+ float id = (itry + iq4k_values[0])/max_scale;
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ const float * xb = x + ibl*Q::kSuperBlockSize;
+ const float * qb = qtmp + ibl*Q::kSuperBlockSize;
+ const float * wb = all_weights + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ int ls = best_index_iq4nl(iq4k_values, id*scales[ib]);
+ float dl = iq4k_values[ls];
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ float q = dl*qb[j];
+ sumqx += wb[j]*xb[j]*q;
+ sumq2 += wb[j]*q*q;
+ }
+ xb += Q::kBlockSize;
+ wb += Q::kBlockSize;
+ qb += Q::kBlockSize;
+ }
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ d = sumqx/sumq2; best = d*sumqx;
+ }
+ }
+
+ float id = d ? 1/d : 0.f;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ auto scales = all_scales + ibl*Q::kNblock;
+ for (int ib = 0; ib < Q::kNblock/2; ++ib) {
+ int ls1 = best_index_iq4nl(iq4k_values, id*scales[ib]);
+ int ls2 = best_index_iq4nl(iq4k_values, id*scales[ib + Q::kNblock/2]);
+ y[ibl].scales[ib] = ls1 | (ls2 << 4);
+ }
+ }
+
+ *dptr = d;
+ if (!d) return;
+
+ for (int iloop = 0; iloop < 1; ++iloop) {
+
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ auto qs = (uint16_t *)y[ibl].ql;
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * xb = xbl + Q::kBlockSize*ib;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ int ls = iq4k_values[(y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf];
+ float dl = d*ls;
+ quantizer.find_best_match(dl, xb, weight, best_idx);
+
+ for (int j = 0; j < Q::kNg; ++j) {
+ qs[j] = best_idx[j];
+ auto xl = xb + Q::kGroupSize*j;
+ auto wl = weight + Q::kGroupSize*j;
+ auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize;
+ for (int k = 0; k < Q::kGroupSize; ++k) {
+ float q = ql[k]*ls;
+ sumqx += wl[k]*xl[k]*q;
+ sumq2 += wl[k]*q*q;
+ }
+ }
+ qs += Q::kNg;
+ }
+ }
+ if (sumq2 > 0) {
+ d = sumqx/sumq2;
+ *dptr = d;
+ if (!d) return;
+ } else {
+ break;
+ }
+
+ if (false) {
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+ auto qs = (uint16_t *)y[ibl].ql;
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * xb = xbl + Q::kBlockSize*ib;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ for (int j = 0; j < Q::kNg; ++j) best_idx[j] = qs[ib*Q::kNg+j];
+ auto pair = quantizer.find_best_scale(xb, weight, best_idx);
+ scales[ib] = pair.first;
+ }
+ }
+ float id = d ? 1/d : 0.f;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ auto scales = all_scales + ibl*Q::kNblock;
+ for (int ib = 0; ib < Q::kNblock/2; ++ib) {
+ int ls1 = best_index_iq4nl(iq4k_values, id*scales[ib]);
+ int ls2 = best_index_iq4nl(iq4k_values, id*scales[ib + Q::kNblock/2]);
+ y[ibl].scales[ib] = ls1 | (ls2 << 4);
+ }
+ }
+ }
+
+ }
+
+}
+}
+
+void quantize_row_iq2_kt_ref(const float * GGML_RESTRICT x, block_iq2_kt * GGML_RESTRICT y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq2_kt(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq2_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq2_kt * y = (block_iq2_kt *)vy;
+ quantize_row_iq2_kt_ref(x, y, k);
+}
+
+size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ auto row_size = ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row);
+ std::vector<float> scales(n_per_row/QuantizerIQ2KT::kBlockSize);
+ std::vector<float> weights(n_per_row);
+ std::vector<float> xtmp(n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrows; ++row) {
+ quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data());
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrows * row_size;
+}
+
+void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) {
+ assert(k % QuantizerIQ2KT::kSuperBlockSize == 0);
+ const int nb = k / QuantizerIQ2KT::kSuperBlockSize;
+ const float * dptr = (const float *)x;
+ const float d = *dptr * QuantizerIQ2KT::kScale;
+ x = (const block_iq2_kt *)(dptr + 1);
+ auto& deq = iq2kt_quantizer();
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ auto yl = y + ibl*QuantizerIQ2KT::kSuperBlockSize;
+ auto yh = yl + QuantizerIQ2KT::kSuperBlockSize/2;
+ const uint16_t * ql = (const uint16_t *)x[ibl].ql;
+ const uint16_t * qh = ql + QuantizerIQ2KT::kNg*QuantizerIQ2KT::kNblock/2;
+ for (int ib = 0; ib < QuantizerIQ2KT::kNblock/2; ++ib) {
+ float sl = d * iq4k_values[x[ibl].scales[ib] & 0xf];
+ float sh = d * iq4k_values[x[ibl].scales[ib] >> 4];
+ for (int ig = 0; ig < QuantizerIQ2KT::kNg; ++ig) {
+ deq.set_values(ql[ig], yl, sl);
+ deq.set_values(qh[ig], yh, sh);
+ yl += QuantizerIQ2KT::kGroupSize;
+ yh += QuantizerIQ2KT::kGroupSize;
+ }
+ ql += QuantizerIQ2KT::kNg;
+ qh += QuantizerIQ2KT::kNg;
+ }
+ }
+}
+
+void vec_dot_iq2_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+
+}
+
+namespace {
+
+using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true>;
+const QuantizerIQ3KT& iq3kt_quantizer() {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ static std::unique_ptr<QuantizerIQ3KT> quantizer;
+ if (!quantizer) quantizer = std::make_unique<QuantizerIQ3KT>(256, 8);
+ return *quantizer;
+}
+
+void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales,
+ float * all_weights, float * qtmp) {
+
+ constexpr float kSigmaScale = 2.0f;
+ constexpr float kStep = 8.0f;
+
+ using Q = QuantizerIQ3KT;
+
+ static_assert(Q::kNumVal%8 == 0);
+
+ constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
+
+ float * dptr = (float *)vy;
+
+ block_iq3_kt * y = (block_iq3_kt *)(dptr + 1);
+
+ int best_idx[2*Q::kNg];
+
+ auto& quantizer = iq3kt_quantizer();
+
+ int nblock = n_per_row / Q::kSuperBlockSize;
+
+ float amax_row = 0;
+ for (int j = 0; j < n_per_row; ++j) amax_row = std::max(amax_row, std::abs(x[j]));
+ if (!amax_row) {
+ *dptr = 0.f;
+ std::memset(y, 0, nblock*sizeof(block_iq3_kt));
+ return;
+ }
+
+ Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights);
+
+ float amax_scale = 0, max_scale = 0;
+
+ float xaux[Q::kBlockSize];
+
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ memset(&y[ibl], 0, sizeof(block_iq3_kt));
+
+ auto scales = all_scales + ibl*Q::kNblock;
+ auto xbl = x + ibl*Q::kSuperBlockSize;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * xb = xbl + Q::kBlockSize*ib;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ float amax = 0;
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ float ax = std::abs(xb[j]);
+ xaux[j] = ax;
+ amax = std::max(amax, ax);
+ }
+ scales[ib] = 0;
+ if (!amax) continue;
+
+ //quantizer.find_best_match(amax/96.f, xaux, weight, best_idx+Q::kNg);
+ //scales[ib] = quantizer.find_best_scale(xaux, weight, best_idx+Q::kNg).first;
+
+ float scale_0 = std::max(84.f, 123.f*amax/amax_row);
+ //float scale_0 = std::max(64.f, 123.f*amax/amax_row);
+ float best = 0;
+ for (int itry = -3; itry <= 3; ++itry) {
+ quantizer.find_best_match(amax/(scale_0 + kStep*itry), xaux, weight, best_idx);
+ auto [d, score] = quantizer.find_best_scale(xaux, weight, best_idx);
+ if (score > best) {
+ best = score;
+ scales[ib] = d;
+ std::memcpy(best_idx+Q::kNg, best_idx, Q::kNg*sizeof(int));
+ }
+ }
+
+ auto xt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ for (int ig = 0; ig < Q::kNg; ++ig) {
+ auto q = quantizer.values() + Q::kGroupSize*best_idx[Q::kNg+ig];
+ for (int j = 0; j < Q::kGroupSize; ++j) *xt++ = q[j];
+ }
+
+ float abs_scale = std::abs(scales[ib]);
+ if (abs_scale > amax_scale) {
+ amax_scale = abs_scale;
+ max_scale = scales[ib];
+ }
+ }
+
+ }
+
+ GGML_ASSERT(max_scale >= 0);
+ float d = max_scale/15;
+ float best = 0;
+ for (int itry = -9; itry <= 9; ++itry) {
+ float id = (itry*0.2f + 15)/max_scale;
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ const float * xb = x + ibl*Q::kSuperBlockSize;
+ const float * qb = qtmp + ibl*Q::kSuperBlockSize;
+ const float * wb = all_weights + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ int ls = nearest_int(id*scales[ib]);
+ ls = std::max(0, std::min(15, ls));
+ float dl = ls;
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ float q = dl*qb[j];
+ sumqx += wb[j]*std::abs(xb[j])*q;
+ sumq2 += wb[j]*q*q;
+ }
+ xb += Q::kBlockSize;
+ wb += Q::kBlockSize;
+ qb += Q::kBlockSize;
+ }
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ d = sumqx/sumq2; best = d*sumqx;
+ }
+ }
+
+ float id = d ? 1/d : 0.f;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ auto scales = all_scales + ibl*Q::kNblock;
+ for (int ib = 0; ib < Q::kNblock/2; ++ib) {
+ int ls1 = nearest_int(id*scales[ib]);
+ int ls2 = nearest_int(id*scales[ib + Q::kNblock/2]);
+ ls1 = std::max(0, std::min(15, ls1));
+ ls2 = std::max(0, std::min(15, ls2));
+ y[ibl].scales[ib] = ls1 | (ls2 << 4);
+ }
+ }
+
+ *dptr = d;
+
+ for (int iloop = 0; iloop < 1; ++iloop) {
+
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ uint16_t * ql = (uint16_t *)y[ibl].ql;
+
+ std::memset(y[ibl].qh, 0, kNumGroups/2);
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * xb = xbl + Q::kBlockSize*ib;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ xaux[j] = std::abs(xb[j]);
+ if (xb[j] < 0) y[ibl].qh[j] |= (1 << ib);
+ }
+ int ls = (y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf;
+ float dl = d*ls;
+ quantizer.find_best_match(dl, xaux, weight, best_idx);
+
+ for (int j = 0; j < Q::kNg; ++j) {
+ ql[ib*Q::kNg+j] = best_idx[j];
+ auto xl = xaux + Q::kGroupSize*j;
+ auto wl = weight + Q::kGroupSize*j;
+ auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize;
+ for (int k = 0; k < Q::kGroupSize; ++k) {
+ float q = ql[k]*ls;
+ sumqx += wl[k]*xl[k]*q;
+ sumq2 += wl[k]*q*q;
+ }
+ }
+ }
+ }
+ if (sumq2 > 0) {
+ d = sumqx/sumq2;
+ *dptr = d;
+ if (!d) break;
+ } else {
+ break;
+ }
+ }
+}
+}
+
+void quantize_row_iq3_kt_ref(const float * x, block_iq3_kt * y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq3_kt(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq3_kt(const float * x, void * vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq3_kt * y = (block_iq3_kt *)vy;
+ quantize_row_iq3_kt_ref(x, y, k);
+}
+
+size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ auto row_size = ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row);
+ std::vector<float> scales(n_per_row/QuantizerIQ3KT::kBlockSize);
+ std::vector<float> weights(n_per_row), xtmp(n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrows; ++row) {
+ quantize_row_iq3_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data());
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrows * row_size;
+}
+
+void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) {
+ using Q = QuantizerIQ3KT;
+ constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
+ assert(k % Q::kSuperBlockSize == 0);
+ const int nb = k / Q::kSuperBlockSize;
+ const float * dptr = (const float *)x;
+ const float d = *dptr * Q::kScale;
+ x = (const block_iq3_kt *)(dptr + 1);
+ auto& deq = iq3kt_quantizer();
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ auto yl = y + ibl*Q::kSuperBlockSize;
+ auto yh = yl + Q::kSuperBlockSize/2;
+ auto qll = (const uint16_t *)x[ibl].ql;
+ auto qlh = qll + kNumGroups/2;
+ int jj = 0;
+ for (int ib = 0; ib < Q::kNblock/2; ++ib) {
+ float sl = d * (x[ibl].scales[ib] & 0xf);
+ float sh = d * (x[ibl].scales[ib] >> 4);
+ uint8_t l_mask = 1 << ib;
+ uint8_t h_mask = l_mask << (Q::kNblock/2);
+ for (int ig = 0; ig < Q::kNg; ++ig) {
+ deq.set_values(qll[jj], yl, sl);
+ deq.set_values(qlh[jj], yh, sh);
+ for (int j = 0; j < Q::kGroupSize; ++j) {
+ if (x[ibl].qh[ig*Q::kGroupSize+j] & l_mask) yl[j] = -yl[j];
+ if (x[ibl].qh[ig*Q::kGroupSize+j] & h_mask) yh[j] = -yh[j];
+ }
+ yl += Q::kGroupSize;
+ yh += Q::kGroupSize;
+ ++jj;
+ }
+ }
+ }
+}
+
+void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ3_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+
+}
+
+// ======================================== iq4_kt
+
+namespace{
+
+using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15>;
+
+const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ static std::unique_ptr<QuantizerIQ4KT> quantizer1;
+ static std::unique_ptr<QuantizerIQ4KT> quantizer2;
+ if (with_offset) {
+ if (!quantizer2) quantizer2 = std::make_unique<QuantizerIQ4KT>(625, 6, 4096+32768);
+ return *quantizer2;
+ }
+ if (!quantizer1) quantizer1 = std::make_unique<QuantizerIQ4KT>(625, 6, 4096);
+ return *quantizer1;
+}
+
+void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights) {
+
+ constexpr float kSigmaScale = 2.0f;
+ constexpr int kNtry = 2;
+ using Q = QuantizerIQ4KT;
+
+ static_assert(Q::kNumVal%8 == 0);
+
+ float * dptr = (float *)vy;
+
+ block_iq4_kt * y = (block_iq4_kt *)(dptr + 2);
+
+ auto& quantizer1 = iq4kt_quantizer();
+ auto& quantizer2 = iq4kt_quantizer(true);
+
+ int nblock = n_per_row / Q::kSuperBlockSize;
+
+ Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights);
+
+ float amax_row = 0, row_av = 0;
+ for (int j = 0; j < n_per_row; ++j) {
+ row_av += x[j];
+ amax_row = std::max(amax_row, std::abs(x[j]));
+ }
+ row_av /= n_per_row;
+ dptr[1] = row_av;
+ if (!amax_row) {
+ dptr[0] = 0.f;
+ std::memset(y, 0, nblock*sizeof(block_iq4_kt));
+ return;
+ }
+
+ int best_idx[2*Q::kNg];
+ float xaux[Q::kBlockSize];
+
+ float amax_scale = 0, max_scale = 0;
+
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ memset(&y[ibl], 0, sizeof(block_iq4_kt));
+
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ float amax = 0;
+ for (int j = 0; j < Q::kBlockSize; ++j) {
+ xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av;
+ float ax = std::abs(xaux[j]);
+ amax = std::max(amax, ax);
+ }
+ if (!amax) {
+ scales[ib] = 0;
+ continue;
+ }
+ float best = 0;
+ float scale_0 = std::max(92.f, 127.f*amax/amax_row);
+ for (int itry = -kNtry; itry <= kNtry; ++itry) {
+ quantizer1.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx);
+ auto [dp, score_p] = quantizer1.find_best_scale(xaux, weight, best_idx);
+ if (score_p > best) {
+ best = score_p; scales[ib] = dp;
+ }
+ quantizer1.find_best_match(-amax/(8.f*itry + scale_0), xaux, weight, best_idx);
+ auto [dm, score_m] = quantizer1.find_best_scale(xaux, weight, best_idx);
+ if (score_m > best) {
+ best = score_m; scales[ib] = dm;
+ }
+ }
+
+ quantizer2.find_best_match(scales[ib], xaux, weight, best_idx);
+ auto [d, score] = quantizer2.find_best_scale(xaux, weight, best_idx);
+ if (score > best) {
+ scales[ib] = d;
+ y[ibl].qs[ib] = 1;
+ }
+ bool with_offset = false;
+ for (int itry = -kNtry; itry <= kNtry; ++itry) {
+ quantizer2.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx);
+ auto [dp, score_p] = quantizer2.find_best_scale(xaux, weight, best_idx);
+ if (score_p > best) {
+ best = score_p; scales[ib] = dp; with_offset = true;
+ }
+ quantizer2.find_best_match(-amax/(8.f*itry + scale_0), xaux, weight, best_idx);
+ auto [dm, score_m] = quantizer2.find_best_scale(xaux, weight, best_idx);
+ if (score_m > best) {
+ best = score_m; scales[ib] = dm; with_offset = true;
+ }
+ }
+ if (with_offset) y[ibl].qs[ib] = 1;
+
+ float abs_scale = std::abs(scales[ib]);
+ if (abs_scale > amax_scale) {
+ amax_scale = abs_scale;
+ max_scale = scales[ib];
+ }
+ }
+
+ }
+
+ float d = -max_scale/64;
+
+ dptr[0] = d;
+ if (!d) return;
+
+ constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
+
+ for (int iloop = 0; iloop < 1; ++iloop) {
+
+ const float id = 1/d;
+
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+
+ // high 3 bits + scales
+ // each block of 32 needs 8 x 3 (high bits) + 1 x 8 (scale) = 32 bits = 1 x uint32_t
+ // we have 8 blocks
+ auto shb = y[ibl].qs; // high 3 bits + scales
+ auto ql = (uint8_t *)(shb + Q::kNblock);
+ auto qh = ql + kNumGroups;
+ std::memset(qh, 0, kNumGroups/2);
+ const float * xbl = x + ibl*Q::kSuperBlockSize;
+ auto scales = all_scales + ibl*Q::kNblock;
+
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ auto& quantizer = y[ibl].qs[ib] & 1 ? quantizer2 : quantizer1;
+ const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
+ for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av;
+ int ls = nearest_int(id*scales[ib]);
+ ls = std::min(ls, 63);
+ *(uint8_t *)(shb + ib) = ((ls + 64) << 1) | (shb[ib] & 1);
+ float dl = d*ls;
+ quantizer.find_best_match(dl, xaux, weight, best_idx);
+
+ for (int j = 0; j < Q::kNg; ++j) {
+ shb[ib] |= ((best_idx[j] >> 12) << (8 + 3*j));
+ ql[Q::kNg*ib + j] = best_idx[j] & 255;
+ qh[(Q::kNg*ib + j)%(kNumGroups/2)] |= ((best_idx[j] >> 8) & 0xf) << 4*((Q::kNg*ib + j)/(kNumGroups/2));
+ auto xl = xaux + Q::kGroupSize*j;
+ auto wl = weight + Q::kGroupSize*j;
+ auto ql = quantizer.values() + Q::kGroupSize*best_idx[j];
+ for (int k = 0; k < Q::kGroupSize; ++k) {
+ float q = ql[k]*ls;
+ sumqx += wl[k]*xl[k]*q;
+ sumq2 += wl[k]*q*q;
+ }
+ }
+ }
+ }
+ if (sumq2 > 0) {
+ d = sumqx/sumq2;
+ dptr[0] = d;
+ if (!d) break;
+ } else {
+ break;
+ }
+ }
+}
+}
+
+void quantize_row_iq4_kt_ref(const float * GGML_RESTRICT x, block_iq4_kt * GGML_RESTRICT y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq4_kt(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq4_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq4_kt * y = (block_iq4_kt *)vy;
+ quantize_row_iq4_kt_ref(x, y, k);
+}
+
+size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ auto row_size = ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row);
+ std::vector<float> scales(n_per_row/QuantizerIQ4KT::kBlockSize);
+ std::vector<float> weights(n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrows; ++row) {
+ quantize_row_iq4_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data());
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrows * row_size;
+}
+
+void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) {
+ using Q = QuantizerIQ4KT;
+ assert(k % Q::kSuperBlockSize == 0);
+ constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
+ const int nb = k / Q::kSuperBlockSize;
+ const float * dptr = (const float *)x;
+ const float d = dptr[0] * Q::kScale;
+ const float row_av = dptr[1];
+ x = (const block_iq4_kt *)(dptr + 2);
+ auto& deq = iq4kt_quantizer();
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ auto shb = x[ibl].qs;
+ auto ql = (const uint8_t *)(shb + Q::kNblock);
+ auto qh = ql + kNumGroups;
+ for (int ib = 0; ib < Q::kNblock; ++ib) {
+ int offset = shb[ib] & 1 ? 32768 + 4096 : 4096;
+ //auto& deq = shb[ib] & 1 ? deq2 : deq1;
+ int ls = int((shb[ib] & 0xff) >> 1) - 64;
+ float sl = d * ls;
+ for (int ig = 0; ig < Q::kNg; ++ig) {
+ int jj = ib*Q::kNg+ig;
+ uint16_t idx = ql[jj] | ((qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00) | (((shb[ib] >> (8 + 3*ig)) & 7) << 12);
+ deq.set_values(idx, y, sl, offset);
+ for (int j = 0; j < Q::kGroupSize; ++j) y[j] += row_av;
+ y += Q::kGroupSize;
+ }
+ }
+ }
+}
+
+void vec_dot_iq4_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+
+}