diff options
Diffstat (limited to 'ggml/src/iqk/iqk_quantize.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 1386 |
1 files changed, 1386 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 93aa2180..c1f7a8e4 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -22,6 +22,8 @@ #include <algorithm> #include <cstring> #include <mutex> +#include <random> +#include <memory> #include <thread> #include <atomic> #include <unordered_map> @@ -7408,3 +7410,1387 @@ void dequantize_row_ms_i2s(const void * vx, float * y, int64_t k) { } } +namespace { +#ifdef __AVX2__ +__m128 hsum_float_4x4(__m128 * accm) { + accm[0] = _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[2]), _mm_unpackhi_ps(accm[0], accm[2])); + accm[1] = _mm_add_ps(_mm_unpacklo_ps(accm[1], accm[3]), _mm_unpackhi_ps(accm[1], accm[3])); + return _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[1]), _mm_unpackhi_ps(accm[0], accm[1])); +} +__m256 hsum_float_8x8(__m256 * accm) { + for (int i = 0; i < 4; ++i) { + accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), + _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} +__m256 hsum_float_4x8(__m256 * accm) { + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} +#endif +template <int block_size, int group_size, int num_bits, bool is_abs = false> +class QuantizerIQKT { + static_assert(group_size == 8 || group_size == 4); + static_assert(block_size >= 8 && block_size%8 == 0); +public: + constexpr static int kSuperBlockSize = QK_K; + constexpr static int kBlockSize = block_size; + constexpr static int kGroupSize = group_size; + constexpr static int kNg = kBlockSize/kGroupSize; + constexpr static int kNblock = kSuperBlockSize/kBlockSize; + constexpr static int kNumVal = 1 << num_bits; // i.e, 16 bits per group of 8 + constexpr static float kScale = 31.75f; + constexpr static bool kVerbose = false; + + QuantizerIQKT(int num_clusters, int num_neighbours, int offset = 4096); + const float * values() const { return m_values.data(); } + + inline void find_best_match(float d, const float * xb, const float * weight, int * best_idx) const; + inline std::pair<float, float> find_best_scale(const float * xb, const float * weight, const int * best_idx) const; + inline float find_best_inverse_scale(const float * xb, const float * weight, const int * best_idx) const; + + static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) { + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + uint32_t x = i + offset; + for (int k = 0; k < kGroupSize; ++k) { + x = ka*x + kb; + uint32_t s = (x & kmask) ^ km32; + float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); + if constexpr (is_abs) result[k] = scale*std::abs(val); + else result[k] = scale*val; + } + } + + static inline int bin4(float x) { + if constexpr (is_abs) { + return x < 16.f ? 0 : x < 32.f ? 1 : x < 64.f ? 2 : 3; + } else { + return x < -24.f ? 0 : x < 0.0f ? 1 : x < 24.f ? 2 : 3; + } + } + static inline int bin5(float x) { + if constexpr (is_abs) { + return x < 11.2f ? 0 : x < 24.f ? 1 : x < 39.f ? 2 : x < 58.f ? 3 : 4; + } else { + return x < -48.f ? 0 : x < -16.f ? 1 : x < 16.f ? 2 : x < 48.f ? 3 : 4; + } + } + inline int bin3(int idim, float x) const { return x < m_mid[2*idim+0] ? 0 : x < m_mid[2*idim+1] ? 1 : 2; } + + static inline void set_weights(float sigma2_scale, int nblock, const float * x, const float * imatrix, float * row_weights) { + for (int ibl = 0; ibl < nblock; ++ibl) { + + const float * xbl = x + ibl*kSuperBlockSize; + float * wbl = row_weights + ibl*kSuperBlockSize; + + float sumx2 = 0; + for (int j = 0; j < kSuperBlockSize; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = sigma2_scale*sumx2/kSuperBlockSize; + + if (imatrix) { + const float * qw = imatrix + ibl*kSuperBlockSize; + for (int j = 0; j < kSuperBlockSize; ++j) wbl[j] = qw[j] * sqrtf(sigma2 + xbl[j]*xbl[j]); + } else { + for (int j = 0; j < kSuperBlockSize; ++j) wbl[j] = 0.25f*sigma2 + xbl[j]*xbl[j]; + } + } + } +private: + static std::vector<float> cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid); + static std::vector<std::vector<int>> finalize_clusters(int num_neighbours, const std::vector<float>& points, const std::vector<float>& clusters, + std::vector<std::vector<float>>& c_values); + std::vector<float> m_values; + std::vector<float> m_clusters; + std::vector<std::vector<int>> m_in_cluster; + std::vector<std::vector<float>> m_c_values; + float m_mid[4*kGroupSize]; +}; + +template <int block_size, int group_size, int num_bits, bool is_abs> +QuantizerIQKT<block_size, group_size, num_bits, is_abs>::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) { + m_values.resize(kNumVal*kGroupSize); + float * data = m_values.data(); + for (int i = 0; i < kNumVal; ++i) { + set_values(i, data, kScale, offset); + data += kGroupSize; + } + // Make 128 clusters. + // Note: we get a slightly better result by using 64 clusters + // at the expense of almost doubling the quantization time. + m_clusters = cluster_points(m_values, num_clusters, 200, m_mid); + GGML_ASSERT(!m_clusters.empty()); + m_in_cluster = finalize_clusters(num_neighbours, m_values, m_clusters, m_c_values); +} + +template <int block_size, int group_size, int num_bits, bool is_abs> +std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_scale( + const float * xb, const float * weight, const int * best_idx) const { + float sumqx = 0, sumq2 = 0; +#ifdef __AVX2__ + auto vqx = _mm256_setzero_ps(); + auto vq2 = _mm256_setzero_ps(); + for (int l = 0; l < kBlockSize; l += 8) { + auto vx = _mm256_loadu_ps(xb+l); + auto vw = _mm256_loadu_ps(weight+l); + auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) : + _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]), + _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0])); + auto vqw = _mm256_mul_ps(vq, vw); + vqx = _mm256_fmadd_ps(vqw, vx, vqx); + vq2 = _mm256_fmadd_ps(vqw, vq, vq2); + } + sumqx = hsum_float_8(vqx); + sumq2 = hsum_float_8(vq2); +#else + for (int l = 0; l < kNg; ++l) { + auto xl = xb + kGroupSize*l; + auto wl = weight + kGroupSize*l; + auto ql = m_values.data() + kGroupSize*best_idx[l]; + for (int k = 0; k < kGroupSize; ++k) { + sumqx += wl[k]*ql[k]*xl[k]; + sumq2 += wl[k]*ql[k]*ql[k]; + } + } +#endif + return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f); +} + +template <int block_size, int group_size, int num_bits, bool is_abs> +float QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_inverse_scale( + const float * xb, const float * weight, const int * best_idx) const { + float sumqx = 0, sumx2 = 0; +#ifdef __AVX2__ + auto vqx = _mm256_setzero_ps(); + auto vx2 = _mm256_setzero_ps(); + for (int l = 0; l < kBlockSize; l += 8) { + auto vx = _mm256_loadu_ps(xb+l); + auto vw = _mm256_loadu_ps(weight+l); + auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) : + _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]), + _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0])); + auto vxw = _mm256_mul_ps(vx, vw); + vx2 = _mm256_fmadd_ps(vxw, vx, vx2); + vqx = _mm256_fmadd_ps(vxw, vq, vqx); + } + sumqx = hsum_float_8(vqx); + sumx2 = hsum_float_8(vx2); +#else + for (int l = 0; l < kNg; ++l) { + auto xl = xb + kGroupSize*l; + auto wl = weight + kGroupSize*l; + auto ql = m_values.data() + kGroupSize*best_idx[l]; + for (int k = 0; k < kGroupSize; ++k) { + sumqx += wl[k]*ql[k]*xl[k]; + sumx2 += wl[k]*xl[k]*xl[k]; + } + } +#endif + return sumx2 > 0 ? sumqx/sumx2 : 0.f; +} + +template <int block_size, int group_size, int num_bits, bool is_abs> +void QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const { + if (!d) { + std::memset(best_idx, 0, kNg*sizeof(int)); + return; + } + int ncluster = m_clusters.size()/kGroupSize; + float id = 1/d; +#ifdef __AVX2__ + if constexpr (kGroupSize == 8) { + __m256 sqx[8]; + const __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + float sx[8]; + int index[8]; + auto vid = _mm256_set1_ps(id); + auto add8 = _mm256_set1_epi32(8); + for (int l = 0; l < kNg; ++l) { + auto xl = xb + 8*l; + auto wl = weight + 8*l; + auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xl)); + auto vw = _mm256_loadu_ps(wl); + int jbest = -1; + if (kGroupSize == 8 && (ncluster == 256 || ncluster == 6561)) { + _mm256_store_ps(sx, vx); + uint16_t u = 0; + if (ncluster == 256) { + for (int j = 0; j < 8; ++j) if (sx[j] > m_mid[j]) u |= (1 << j); + } else { + int s = 1; + for (int j = 0; j < 8; ++j) { u += s*bin3(j, sx[j]); s *= 3; } + } + jbest = u; + } else { + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + float best = INFINITY; + auto idx = add_idx; + for (int j = 0; j < ncluster; j += 8) { + for (int i = 0; i < 8; ++i) { + auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+i)); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + } + auto score = hsum_float_8x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + idx = _mm256_add_epi32(idx, add8); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + } + auto& points = m_in_cluster[jbest]; + auto& values = points.empty() ? m_values : m_c_values[jbest]; + int npoint = values.size()/kGroupSize; + GGML_ASSERT(npoint > 0 && npoint%8 == 0); + int jbest_cluster = jbest; + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + auto best = INFINITY; jbest = -1; + auto idx = add_idx; + for (int j = 0; j < npoint; j += 8) { + for (int i = 0; i < 8; ++i) { + auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+i)); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + } + auto score = hsum_float_8x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + idx = _mm256_add_epi32(idx, add8); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + if (jbest < 0) { + fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size())); + GGML_ASSERT(false); + } + best_idx[l] = points.empty() ? jbest : points[jbest]; + } + } else { + __m256 sqx[4]; + const __m256i add_idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + const __m256 sign_bit = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff)); + float sx[8]; + int index[8]; + auto vid_p = _mm256_set1_ps(id); + auto add8 = _mm256_set1_epi32(8); + for (int l = 0; l < kNg; ++l) { + auto xl = xb + 4*l; + auto wl = weight + 4*l; + auto vx4 = _mm_loadu_ps(xl); + auto vx = _mm256_mul_ps(vid_p, _mm256_set_m128(vx4, vx4)); + auto vw4 = _mm_loadu_ps(wl); + auto vw = _mm256_set_m128(vw4, vw4); + int jbest = -1; + if (ncluster == 256 || ncluster == 625) { + _mm256_storeu_ps(sx, vx); + uint16_t u = 0; + if (ncluster == 256) { + for (int k = 0; k < 4; ++k) u |= (bin4(sx[k]) << 2*k); + } else { + int l = 1; + for (int k = 0; k < 4; ++k) { u += bin5(sx[k])*l; l *= 5; } + } + jbest = u; + } else { + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + float best = INFINITY; + auto idx = add_idx; + for (int j = 0; j < ncluster; j += 8) { + for (int i = 0; i < 4; ++i) { + auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+2*i)); + auto vdiff = _mm256_sub_ps(vq, vx); + vdiff = _mm256_and_ps(sign_bit, vdiff); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); + } + auto score = hsum_float_4x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + idx = _mm256_add_epi32(idx, add8); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + } + auto& points = m_in_cluster[jbest]; + auto& values = m_c_values[jbest]; + GGML_ASSERT(!points.empty() && points.size()%8 == 0); + int jbest_cluster = jbest; + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + float best = INFINITY; jbest = -1; + auto idx = add_idx; + for (int j = 0; j < int(points.size()); j += 8) { + for (int i = 0; i < 4; ++i) { + auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+2*i)); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + } + auto score = hsum_float_4x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + idx = _mm256_add_epi32(idx, add8); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + if (jbest < 0) { + fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size())); + GGML_ASSERT(false); + } + best_idx[l] = points[jbest]; + } + } +#else + // TODO + std::memset(best_idx, 0, kNg*sizeof(int)); +#endif +} + +template <int block_size, int group_size, int num_bits, bool is_abs> +std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::finalize_clusters(int num_neighbours, + const std::vector<float>& values, const std::vector<float>& clusters, std::vector<std::vector<float>>& c_values) { + int ncluster = clusters.size()/kGroupSize; + std::vector<std::vector<int>> p_in_cluster(ncluster); + std::vector<int> which_cluster(num_neighbours*kNumVal); + std::vector<int> ibest(num_neighbours); + std::vector<float> best(num_neighbours); + for (int ip = 0; ip < kNumVal; ++ip) { + auto vp = values.data() + ip*kGroupSize; + for (int j = 0; j < num_neighbours; ++j) { + best[j] = INFINITY; ibest[j] = -1; + } + for (int ic = 0; ic < ncluster; ++ic) { + auto vc = clusters.data() + ic*kGroupSize; + float dist2 = 0; + for (int k = 0; k < kGroupSize; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + for (int j = 0; j < num_neighbours; ++j) { + if (dist2 < best[j]) { + for (int k = num_neighbours-1; k > j; --k) { + best[k] = best[k-1]; ibest[k] = ibest[k-1]; + } + best[j] = dist2; ibest[j] = ic; + break; + } + } + } + for (int j = 0; j < num_neighbours; ++j) { + if (ibest[j] < 0) { + printf("Oops: ibest[%d] = %d\n", j, ibest[j]); + } + GGML_ASSERT(ibest[j] >= 0); + p_in_cluster[ibest[j]].push_back(ip); + } + std::memcpy(which_cluster.data() + num_neighbours*ip, ibest.data(), num_neighbours*sizeof(int)); + } + std::vector<std::pair<float, int>> extra; + extra.reserve(kNumVal); + for (int ic = 0; ic < ncluster; ++ic) { + auto& points = p_in_cluster[ic]; + if (!points.empty() && points.size()%8 == 0) continue; + extra.clear(); + auto vc = clusters.data() + ic*kGroupSize; + for (int ip = 0; ip < kNumVal; ++ip) { + bool can_add = true; + for (int j = 0; j < num_neighbours; ++j) { + if (which_cluster[num_neighbours*ip+j] == ic) { can_add = false; break; } + } + if (!can_add) continue; + auto vp = values.data() + ip*kGroupSize; + float dist2 = 0; + for (int k = 0; k < kGroupSize; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + extra.push_back(std::make_pair(dist2, ip)); + } + std::sort(extra.begin(), extra.end()); + int nadd = 8*((points.size()+7)/8) - points.size(); + for (int i = 0; i < nadd; ++i) points.push_back(extra[i].second); + GGML_ASSERT(points.size()%8 == 0); + } + auto min = p_in_cluster.front().size(), max = p_in_cluster.front().size(); + for (auto& points : p_in_cluster) { + min = std::min(min, points.size()); + max = std::max(max, points.size()); + } + c_values.resize(p_in_cluster.size()); + for (int i = 0; i < int(p_in_cluster.size()); ++i) { + auto& points = p_in_cluster[i]; + c_values[i].resize(points.size()*kGroupSize); + auto ptr = c_values[i].data(); + for (auto j : points) { + std::memcpy(ptr, values.data() + j*kGroupSize, kGroupSize*sizeof(float)); + ptr += kGroupSize; + } + } + + if (kVerbose) { + printf("%s: prepared %d clusters\n", __func__, ncluster); + printf(" min number of points in a cluster: %d\n", int(min)); + printf(" max number of points in a cluster: %d\n", int(max)); + } + return p_in_cluster; +} + +template <int block_size, int group_size, int num_bits, bool is_abs> +std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid) { + constexpr int ndim = kGroupSize; + GGML_ASSERT(points.size() % ndim == 0); + int npoint = points.size() / ndim; + GGML_ASSERT(npoint >= 2*ncluster); + std::vector<std::pair<float, float>> range(ndim, std::make_pair(INFINITY, -INFINITY)); + double Fo = 0; + for (int i = 0; i < npoint; ++i) { + auto v = points.data() + i*ndim; + for (int k = 0; k < ndim; ++k) { + Fo += v[k]*v[k]; + range[k].first = std::min(range[k].first, v[k]); + range[k].second = std::max(range[k].second, v[k]); + } + } + if (kVerbose) printf("%s (ndim = %d, npoint = %d): Fo = %g\n", __func__, ndim, npoint, Fo/points.size()); + if constexpr (is_abs) { + std::vector<int> P(npoint); + for (int idim = 0; idim < ndim; ++idim) { + for (int ip = 0; ip < npoint; ++ip) P[ip] = points[ip*ndim+idim]; + std::sort(P.begin(), P.end()); + if (ndim == 8 && ncluster == 6561) { + mid[2*idim + 0] = P[npoint/3]; + mid[2*idim + 1] = P[2*npoint/3]; + } else { + mid[idim] = npoint%2 == 0 ? 0.5f*(P[npoint/2] + P[npoint/2-1]) : P[npoint/2]; + if (kVerbose) printf("%s: mid[%d] = %g\n", __func__, idim, mid[idim]); + } + } + } else { + for (int k = 0; k < ndim; ++k) mid[k] = 0.5f*(range[k].first + range[k].second); + } + std::vector<float> sump(ncluster*ndim); + std::vector<int> counts(ncluster); + std::vector<float> result(ncluster*ndim); + if (ndim == 8 && (ncluster == 256 || ncluster == 6561)) { + std::memset(sump.data(), 0, sump.size()*sizeof(float)); + std::memset(counts.data(), 0, counts.size()*sizeof(int)); + for (int ip = 0; ip < npoint; ++ip) { + auto vp = points.data() + ndim*ip; + uint16_t u = 0; + if (ncluster == 256) { + for (int k = 0; k < ndim; ++k) if (vp[k] > mid[k]) u |= (1 << k); + } else { + int s = 1; + for (int k = 0; k < ndim; ++k) { + int bin = vp[k] < mid[2*k+0] ? 0 : vp[k] < mid[2*k+1] ? 1 : 2; + u += s*bin; s *= 3; + } + } + ++counts[u]; + for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k]; + } + for (int ic = 0; ic < ncluster; ++ic) { + if (!counts[ic]) { + printf("%s: Oops. Cluster %d has no points\n", __func__, ic); + GGML_ABORT("fatal error"); + } + for (int k = 0; k < ndim; ++k) result[ic*ndim + k] = sump[ic*ndim + k]/counts[ic]; + } + return result; + } + else if (ndim == 4 && (ncluster == 256 || ncluster == 625)) { + std::memset(sump.data(), 0, sump.size()*sizeof(float)); + std::memset(counts.data(), 0, counts.size()*sizeof(int)); + for (int ip = 0; ip < npoint; ++ip) { + auto vp = points.data() + ndim*ip; + uint16_t u = 0; + if (ncluster == 256) { + for (int k = 0; k < ndim; ++k) u |= (bin4(vp[k]) << 2*k); + } else { + int s = 1; + for (int k = 0; k < ndim; ++k) { u += s*bin5(vp[k]); s *= 5; } + } + if (u >= int(counts.size())) { + printf("Oops: u = %u, vp = %g, %g, %g, %g\n", u, vp[0], vp[1], vp[2], vp[3]); + u = 0; + if (ncluster == 256) { + for (int k = 0; k < ndim; ++k) { + auto bin = bin4(vp[k]); u |= (bin << 2*k); + printf(" bin[%d] = %d, u = %u", k, bin, u); + } + } else { + for (int k = 0; k < ndim; ++k) printf(" bin[%d] = %d", k, bin5(vp[k])); + } + printf("\n"); + GGML_ABORT("fatal error"); + } + ++counts[u]; + for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k]; + } + int nzero = 0; + for (int ic = 0; ic < ncluster; ++ic) { + if (!counts[ic]) { + ++nzero; + printf("%s: Oops. Cluster %d has no points: ", __func__, ic); + for (int k = 0; k < ndim; ++k) { + int l = (ic >> 2*k) & 3; + printf(" %d", l); + } + printf("\n"); + } else { + for (int k = 0; k < ndim; ++k) result[ic*ndim + k] = sump[ic*ndim + k]/counts[ic]; + } + } + if (nzero > 0) printf("%s: %d out of %d clusters dir not have any points\n", __func__, nzero, ncluster); + return result; + } + std::mt19937 rndm(1234); + float scale = 1.f/4294967296.f; + for (int i = 0; i < ncluster; ++i) { + auto v = result.data() + i*ndim; + for (int k = 0; k < ndim; ++k) v[k] = range[k].first + (range[k].second - range[k].first)*scale*rndm(); + } + std::vector<int> which_cluster(npoint, -1); + double Flast = Fo; + for (int iter = 0; iter < niter; ++iter) { + std::memset(sump.data(), 0, sump.size()*sizeof(float)); + std::memset(counts.data(), 0, counts.size()*sizeof(int)); + int nchanged = 0; + double F = 0; + for (int ip = 0; ip < npoint; ++ip) { + auto vp = points.data() + ndim*ip; + float best = INFINITY; int ibest = -1; + for (int ic = 0; ic < ncluster; ++ic) { + auto vc = result.data() + ndim*ic; + float dist2 = 0; + for (int k = 0; k < ndim; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + if (dist2 < best) { + best = dist2; ibest = ic; + } + } + if (ibest < 0) { + printf("Oops(iteration %d) - failed to find cluster for point", iter); + for (int k = 0; k < ndim; ++k) printf(" %g", vp[k]); + printf("\nHave %d clusters\n", ncluster); + } + GGML_ASSERT(ibest >= 0); + F += best; + if (which_cluster[ip] != ibest) ++nchanged; + which_cluster[ip] = ibest; + ++counts[ibest]; + auto vc = sump.data() + ndim*ibest; + for (int k = 0; k < ndim; ++k) vc[k] += vp[k]; + } + if (nchanged == 0) break; + for (int ic = 0; ic < ncluster; ++ic) { + float norm = counts[ic] > 0 ? 1.f/counts[ic] : 0.f; + auto vc = sump.data() + ndim*ic; + auto r = result.data() + ndim*ic; + for (int k = 0; k < ndim; ++k) r[k] = vc[k]*norm; + } + if (kVerbose) printf("%s(iteration %d): F = %g, nchanged = %d\n", __func__, iter+1, F/points.size(), nchanged); + if (iter > 1 && Flast/F - 1 < 1e-6) break; + Flast = F; + } + int nzero = 0; + for (int ic = 0; ic < ncluster; ++ic) { + if (!counts[ic]) ++nzero; + } + if (nzero > 0) printf("%s: there are %d empty clusters\n", __func__, nzero); + return result; +} + +// ========================================== iq2_kt ==================================================== + +using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16>; + +const QuantizerIQ2KT& iq2kt_quantizer() { + static std::mutex mutex; + static std::unique_ptr<QuantizerIQ2KT> quantizer; + std::lock_guard<std::mutex> lock(mutex); + if (!quantizer) quantizer = std::make_unique<QuantizerIQ2KT>(256, 8); + return *quantizer; +} + +void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights, + float * qtmp) { + + constexpr float kSigmaScale = 2.0f; + using Q = QuantizerIQ2KT; + + static_assert(Q::kNumVal%8 == 0); + + float * dptr = (float *)vy; + + block_iq2_kt * y = (block_iq2_kt *)(dptr + 1); + + int best_idx[2*Q::kNg]; + + auto& quantizer = iq2kt_quantizer(); + + int nblock = n_per_row / Q::kSuperBlockSize; + + Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); + + float amax_scale = 0, max_scale = 0; + + for (int ibl = 0; ibl < nblock; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq2_kt)); + + const float * xbl = x + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + float amax = 0; + for (int j = 0; j < Q::kBlockSize; ++j) { + float ax = std::abs(xb[j]); + amax = std::max(amax, ax); + } + quantizer.find_best_match( amax/96.f, xb, weight, best_idx); + auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); + quantizer.find_best_match(-amax/96.f, xb, weight, best_idx + Q::kNg); + auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg); + + auto idx = best_idx; + if (score_p > score_m) scales[ib] = dp; + else { + scales[ib] = dm; idx += Q::kNg; + } + auto qt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto q = quantizer.values() + idx[ig]*Q::kGroupSize; + for (int j = 0; j < Q::kGroupSize; ++j) qt[j] = q[j]; + qt += Q::kGroupSize; + } + + float abs_scale = std::abs(scales[ib]); + if (abs_scale > amax_scale) { + amax_scale = abs_scale; + max_scale = scales[ib]; + } + } + + } + + if (!max_scale) { + *dptr = 0; + return; + } + + float d = max_scale/iq4k_values[0]; + float best = 0; + for (int itry = -9; itry <= 9; ++itry) { + float id = (itry + iq4k_values[0])/max_scale; + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * xb = x + ibl*Q::kSuperBlockSize; + const float * qb = qtmp + ibl*Q::kSuperBlockSize; + const float * wb = all_weights + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock; ++ib) { + int ls = best_index_iq4nl(iq4k_values, id*scales[ib]); + float dl = iq4k_values[ls]; + for (int j = 0; j < Q::kBlockSize; ++j) { + float q = dl*qb[j]; + sumqx += wb[j]*xb[j]*q; + sumq2 += wb[j]*q*q; + } + xb += Q::kBlockSize; + wb += Q::kBlockSize; + qb += Q::kBlockSize; + } + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; + } + } + + float id = d ? 1/d : 0.f; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock/2; ++ib) { + int ls1 = best_index_iq4nl(iq4k_values, id*scales[ib]); + int ls2 = best_index_iq4nl(iq4k_values, id*scales[ib + Q::kNblock/2]); + y[ibl].scales[ib] = ls1 | (ls2 << 4); + } + } + + *dptr = d; + if (!d) return; + + for (int iloop = 0; iloop < 1; ++iloop) { + + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + + auto qs = (uint16_t *)y[ibl].ql; + const float * xbl = x + ibl*Q::kSuperBlockSize; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + int ls = iq4k_values[(y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf]; + float dl = d*ls; + quantizer.find_best_match(dl, xb, weight, best_idx); + + for (int j = 0; j < Q::kNg; ++j) { + qs[j] = best_idx[j]; + auto xl = xb + Q::kGroupSize*j; + auto wl = weight + Q::kGroupSize*j; + auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize; + for (int k = 0; k < Q::kGroupSize; ++k) { + float q = ql[k]*ls; + sumqx += wl[k]*xl[k]*q; + sumq2 += wl[k]*q*q; + } + } + qs += Q::kNg; + } + } + if (sumq2 > 0) { + d = sumqx/sumq2; + *dptr = d; + if (!d) return; + } else { + break; + } + + if (false) { + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * xbl = x + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + auto qs = (uint16_t *)y[ibl].ql; + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int j = 0; j < Q::kNg; ++j) best_idx[j] = qs[ib*Q::kNg+j]; + auto pair = quantizer.find_best_scale(xb, weight, best_idx); + scales[ib] = pair.first; + } + } + float id = d ? 1/d : 0.f; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock/2; ++ib) { + int ls1 = best_index_iq4nl(iq4k_values, id*scales[ib]); + int ls2 = best_index_iq4nl(iq4k_values, id*scales[ib + Q::kNblock/2]); + y[ibl].scales[ib] = ls1 | (ls2 << 4); + } + } + } + + } + +} +} + +void quantize_row_iq2_kt_ref(const float * GGML_RESTRICT x, block_iq2_kt * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq2_kt(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq2_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_iq2_kt * y = (block_iq2_kt *)vy; + quantize_row_iq2_kt_ref(x, y, k); +} + +size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row); + std::vector<float> scales(n_per_row/QuantizerIQ2KT::kBlockSize); + std::vector<float> weights(n_per_row); + std::vector<float> xtmp(n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data()); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) { + assert(k % QuantizerIQ2KT::kSuperBlockSize == 0); + const int nb = k / QuantizerIQ2KT::kSuperBlockSize; + const float * dptr = (const float *)x; + const float d = *dptr * QuantizerIQ2KT::kScale; + x = (const block_iq2_kt *)(dptr + 1); + auto& deq = iq2kt_quantizer(); + for (int ibl = 0; ibl < nb; ++ibl) { + auto yl = y + ibl*QuantizerIQ2KT::kSuperBlockSize; + auto yh = yl + QuantizerIQ2KT::kSuperBlockSize/2; + const uint16_t * ql = (const uint16_t *)x[ibl].ql; + const uint16_t * qh = ql + QuantizerIQ2KT::kNg*QuantizerIQ2KT::kNblock/2; + for (int ib = 0; ib < QuantizerIQ2KT::kNblock/2; ++ib) { + float sl = d * iq4k_values[x[ibl].scales[ib] & 0xf]; + float sh = d * iq4k_values[x[ibl].scales[ib] >> 4]; + for (int ig = 0; ig < QuantizerIQ2KT::kNg; ++ig) { + deq.set_values(ql[ig], yl, sl); + deq.set_values(qh[ig], yh, sh); + yl += QuantizerIQ2KT::kGroupSize; + yh += QuantizerIQ2KT::kGroupSize; + } + ql += QuantizerIQ2KT::kNg; + qh += QuantizerIQ2KT::kNg; + } + } +} + +void vec_dot_iq2_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + +} + +namespace { + +using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true>; +const QuantizerIQ3KT& iq3kt_quantizer() { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + static std::unique_ptr<QuantizerIQ3KT> quantizer; + if (!quantizer) quantizer = std::make_unique<QuantizerIQ3KT>(256, 8); + return *quantizer; +} + +void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, + float * all_weights, float * qtmp) { + + constexpr float kSigmaScale = 2.0f; + constexpr float kStep = 8.0f; + + using Q = QuantizerIQ3KT; + + static_assert(Q::kNumVal%8 == 0); + + constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize; + + float * dptr = (float *)vy; + + block_iq3_kt * y = (block_iq3_kt *)(dptr + 1); + + int best_idx[2*Q::kNg]; + + auto& quantizer = iq3kt_quantizer(); + + int nblock = n_per_row / Q::kSuperBlockSize; + + float amax_row = 0; + for (int j = 0; j < n_per_row; ++j) amax_row = std::max(amax_row, std::abs(x[j])); + if (!amax_row) { + *dptr = 0.f; + std::memset(y, 0, nblock*sizeof(block_iq3_kt)); + return; + } + + Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); + + float amax_scale = 0, max_scale = 0; + + float xaux[Q::kBlockSize]; + + for (int ibl = 0; ibl < nblock; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq3_kt)); + + auto scales = all_scales + ibl*Q::kNblock; + auto xbl = x + ibl*Q::kSuperBlockSize; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + float amax = 0; + for (int j = 0; j < Q::kBlockSize; ++j) { + float ax = std::abs(xb[j]); + xaux[j] = ax; + amax = std::max(amax, ax); + } + scales[ib] = 0; + if (!amax) continue; + + //quantizer.find_best_match(amax/96.f, xaux, weight, best_idx+Q::kNg); + //scales[ib] = quantizer.find_best_scale(xaux, weight, best_idx+Q::kNg).first; + + float scale_0 = std::max(84.f, 123.f*amax/amax_row); + //float scale_0 = std::max(64.f, 123.f*amax/amax_row); + float best = 0; + for (int itry = -3; itry <= 3; ++itry) { + quantizer.find_best_match(amax/(scale_0 + kStep*itry), xaux, weight, best_idx); + auto [d, score] = quantizer.find_best_scale(xaux, weight, best_idx); + if (score > best) { + best = score; + scales[ib] = d; + std::memcpy(best_idx+Q::kNg, best_idx, Q::kNg*sizeof(int)); + } + } + + auto xt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto q = quantizer.values() + Q::kGroupSize*best_idx[Q::kNg+ig]; + for (int j = 0; j < Q::kGroupSize; ++j) *xt++ = q[j]; + } + + float abs_scale = std::abs(scales[ib]); + if (abs_scale > amax_scale) { + amax_scale = abs_scale; + max_scale = scales[ib]; + } + } + + } + + GGML_ASSERT(max_scale >= 0); + float d = max_scale/15; + float best = 0; + for (int itry = -9; itry <= 9; ++itry) { + float id = (itry*0.2f + 15)/max_scale; + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * xb = x + ibl*Q::kSuperBlockSize; + const float * qb = qtmp + ibl*Q::kSuperBlockSize; + const float * wb = all_weights + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock; ++ib) { + int ls = nearest_int(id*scales[ib]); + ls = std::max(0, std::min(15, ls)); + float dl = ls; + for (int j = 0; j < Q::kBlockSize; ++j) { + float q = dl*qb[j]; + sumqx += wb[j]*std::abs(xb[j])*q; + sumq2 += wb[j]*q*q; + } + xb += Q::kBlockSize; + wb += Q::kBlockSize; + qb += Q::kBlockSize; + } + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; + } + } + + float id = d ? 1/d : 0.f; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock/2; ++ib) { + int ls1 = nearest_int(id*scales[ib]); + int ls2 = nearest_int(id*scales[ib + Q::kNblock/2]); + ls1 = std::max(0, std::min(15, ls1)); + ls2 = std::max(0, std::min(15, ls2)); + y[ibl].scales[ib] = ls1 | (ls2 << 4); + } + } + + *dptr = d; + + for (int iloop = 0; iloop < 1; ++iloop) { + + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + + uint16_t * ql = (uint16_t *)y[ibl].ql; + + std::memset(y[ibl].qh, 0, kNumGroups/2); + const float * xbl = x + ibl*Q::kSuperBlockSize; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int j = 0; j < Q::kBlockSize; ++j) { + xaux[j] = std::abs(xb[j]); + if (xb[j] < 0) y[ibl].qh[j] |= (1 << ib); + } + int ls = (y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf; + float dl = d*ls; + quantizer.find_best_match(dl, xaux, weight, best_idx); + + for (int j = 0; j < Q::kNg; ++j) { + ql[ib*Q::kNg+j] = best_idx[j]; + auto xl = xaux + Q::kGroupSize*j; + auto wl = weight + Q::kGroupSize*j; + auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize; + for (int k = 0; k < Q::kGroupSize; ++k) { + float q = ql[k]*ls; + sumqx += wl[k]*xl[k]*q; + sumq2 += wl[k]*q*q; + } + } + } + } + if (sumq2 > 0) { + d = sumqx/sumq2; + *dptr = d; + if (!d) break; + } else { + break; + } + } +} +} + +void quantize_row_iq3_kt_ref(const float * x, block_iq3_kt * y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq3_kt(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq3_kt(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + block_iq3_kt * y = (block_iq3_kt *)vy; + quantize_row_iq3_kt_ref(x, y, k); +} + +size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row); + std::vector<float> scales(n_per_row/QuantizerIQ3KT::kBlockSize); + std::vector<float> weights(n_per_row), xtmp(n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq3_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data()); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) { + using Q = QuantizerIQ3KT; + constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize; + assert(k % Q::kSuperBlockSize == 0); + const int nb = k / Q::kSuperBlockSize; + const float * dptr = (const float *)x; + const float d = *dptr * Q::kScale; + x = (const block_iq3_kt *)(dptr + 1); + auto& deq = iq3kt_quantizer(); + for (int ibl = 0; ibl < nb; ++ibl) { + auto yl = y + ibl*Q::kSuperBlockSize; + auto yh = yl + Q::kSuperBlockSize/2; + auto qll = (const uint16_t *)x[ibl].ql; + auto qlh = qll + kNumGroups/2; + int jj = 0; + for (int ib = 0; ib < Q::kNblock/2; ++ib) { + float sl = d * (x[ibl].scales[ib] & 0xf); + float sh = d * (x[ibl].scales[ib] >> 4); + uint8_t l_mask = 1 << ib; + uint8_t h_mask = l_mask << (Q::kNblock/2); + for (int ig = 0; ig < Q::kNg; ++ig) { + deq.set_values(qll[jj], yl, sl); + deq.set_values(qlh[jj], yh, sh); + for (int j = 0; j < Q::kGroupSize; ++j) { + if (x[ibl].qh[ig*Q::kGroupSize+j] & l_mask) yl[j] = -yl[j]; + if (x[ibl].qh[ig*Q::kGroupSize+j] & h_mask) yh[j] = -yh[j]; + } + yl += Q::kGroupSize; + yh += Q::kGroupSize; + ++jj; + } + } + } +} + +void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ3_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + +} + +// ======================================== iq4_kt + +namespace{ + +using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15>; + +const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + static std::unique_ptr<QuantizerIQ4KT> quantizer1; + static std::unique_ptr<QuantizerIQ4KT> quantizer2; + if (with_offset) { + if (!quantizer2) quantizer2 = std::make_unique<QuantizerIQ4KT>(625, 6, 4096+32768); + return *quantizer2; + } + if (!quantizer1) quantizer1 = std::make_unique<QuantizerIQ4KT>(625, 6, 4096); + return *quantizer1; +} + +void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights) { + + constexpr float kSigmaScale = 2.0f; + constexpr int kNtry = 2; + using Q = QuantizerIQ4KT; + + static_assert(Q::kNumVal%8 == 0); + + float * dptr = (float *)vy; + + block_iq4_kt * y = (block_iq4_kt *)(dptr + 2); + + auto& quantizer1 = iq4kt_quantizer(); + auto& quantizer2 = iq4kt_quantizer(true); + + int nblock = n_per_row / Q::kSuperBlockSize; + + Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); + + float amax_row = 0, row_av = 0; + for (int j = 0; j < n_per_row; ++j) { + row_av += x[j]; + amax_row = std::max(amax_row, std::abs(x[j])); + } + row_av /= n_per_row; + dptr[1] = row_av; + if (!amax_row) { + dptr[0] = 0.f; + std::memset(y, 0, nblock*sizeof(block_iq4_kt)); + return; + } + + int best_idx[2*Q::kNg]; + float xaux[Q::kBlockSize]; + + float amax_scale = 0, max_scale = 0; + + for (int ibl = 0; ibl < nblock; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq4_kt)); + + const float * xbl = x + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + float amax = 0; + for (int j = 0; j < Q::kBlockSize; ++j) { + xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av; + float ax = std::abs(xaux[j]); + amax = std::max(amax, ax); + } + if (!amax) { + scales[ib] = 0; + continue; + } + float best = 0; + float scale_0 = std::max(92.f, 127.f*amax/amax_row); + for (int itry = -kNtry; itry <= kNtry; ++itry) { + quantizer1.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx); + auto [dp, score_p] = quantizer1.find_best_scale(xaux, weight, best_idx); + if (score_p > best) { + best = score_p; scales[ib] = dp; + } + quantizer1.find_best_match(-amax/(8.f*itry + scale_0), xaux, weight, best_idx); + auto [dm, score_m] = quantizer1.find_best_scale(xaux, weight, best_idx); + if (score_m > best) { + best = score_m; scales[ib] = dm; + } + } + + quantizer2.find_best_match(scales[ib], xaux, weight, best_idx); + auto [d, score] = quantizer2.find_best_scale(xaux, weight, best_idx); + if (score > best) { + scales[ib] = d; + y[ibl].qs[ib] = 1; + } + bool with_offset = false; + for (int itry = -kNtry; itry <= kNtry; ++itry) { + quantizer2.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx); + auto [dp, score_p] = quantizer2.find_best_scale(xaux, weight, best_idx); + if (score_p > best) { + best = score_p; scales[ib] = dp; with_offset = true; + } + quantizer2.find_best_match(-amax/(8.f*itry + scale_0), xaux, weight, best_idx); + auto [dm, score_m] = quantizer2.find_best_scale(xaux, weight, best_idx); + if (score_m > best) { + best = score_m; scales[ib] = dm; with_offset = true; + } + } + if (with_offset) y[ibl].qs[ib] = 1; + + float abs_scale = std::abs(scales[ib]); + if (abs_scale > amax_scale) { + amax_scale = abs_scale; + max_scale = scales[ib]; + } + } + + } + + float d = -max_scale/64; + + dptr[0] = d; + if (!d) return; + + constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize; + + for (int iloop = 0; iloop < 1; ++iloop) { + + const float id = 1/d; + + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + + // high 3 bits + scales + // each block of 32 needs 8 x 3 (high bits) + 1 x 8 (scale) = 32 bits = 1 x uint32_t + // we have 8 blocks + auto shb = y[ibl].qs; // high 3 bits + scales + auto ql = (uint8_t *)(shb + Q::kNblock); + auto qh = ql + kNumGroups; + std::memset(qh, 0, kNumGroups/2); + const float * xbl = x + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + auto& quantizer = y[ibl].qs[ib] & 1 ? quantizer2 : quantizer1; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av; + int ls = nearest_int(id*scales[ib]); + ls = std::min(ls, 63); + *(uint8_t *)(shb + ib) = ((ls + 64) << 1) | (shb[ib] & 1); + float dl = d*ls; + quantizer.find_best_match(dl, xaux, weight, best_idx); + + for (int j = 0; j < Q::kNg; ++j) { + shb[ib] |= ((best_idx[j] >> 12) << (8 + 3*j)); + ql[Q::kNg*ib + j] = best_idx[j] & 255; + qh[(Q::kNg*ib + j)%(kNumGroups/2)] |= ((best_idx[j] >> 8) & 0xf) << 4*((Q::kNg*ib + j)/(kNumGroups/2)); + auto xl = xaux + Q::kGroupSize*j; + auto wl = weight + Q::kGroupSize*j; + auto ql = quantizer.values() + Q::kGroupSize*best_idx[j]; + for (int k = 0; k < Q::kGroupSize; ++k) { + float q = ql[k]*ls; + sumqx += wl[k]*xl[k]*q; + sumq2 += wl[k]*q*q; + } + } + } + } + if (sumq2 > 0) { + d = sumqx/sumq2; + dptr[0] = d; + if (!d) break; + } else { + break; + } + } +} +} + +void quantize_row_iq4_kt_ref(const float * GGML_RESTRICT x, block_iq4_kt * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_kt(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq4_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_iq4_kt * y = (block_iq4_kt *)vy; + quantize_row_iq4_kt_ref(x, y, k); +} + +size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row); + std::vector<float> scales(n_per_row/QuantizerIQ4KT::kBlockSize); + std::vector<float> weights(n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq4_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data()); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { + using Q = QuantizerIQ4KT; + assert(k % Q::kSuperBlockSize == 0); + constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize; + const int nb = k / Q::kSuperBlockSize; + const float * dptr = (const float *)x; + const float d = dptr[0] * Q::kScale; + const float row_av = dptr[1]; + x = (const block_iq4_kt *)(dptr + 2); + auto& deq = iq4kt_quantizer(); + for (int ibl = 0; ibl < nb; ++ibl) { + auto shb = x[ibl].qs; + auto ql = (const uint8_t *)(shb + Q::kNblock); + auto qh = ql + kNumGroups; + for (int ib = 0; ib < Q::kNblock; ++ib) { + int offset = shb[ib] & 1 ? 32768 + 4096 : 4096; + //auto& deq = shb[ib] & 1 ? deq2 : deq1; + int ls = int((shb[ib] & 0xff) >> 1) - 64; + float sl = d * ls; + for (int ig = 0; ig < Q::kNg; ++ig) { + int jj = ib*Q::kNg+ig; + uint16_t idx = ql[jj] | ((qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00) | (((shb[ib] >> (8 + 3*ig)) & 7) << 12); + deq.set_values(idx, y, sl, offset); + for (int j = 0; j < Q::kGroupSize; ++j) y[j] += row_av; + y += Q::kGroupSize; + } + } + } +} + +void vec_dot_iq4_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + +} |