summary refs log tree commit diff
path: root/ggml-quants.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-quants.c')
-rw-r--r-- ggml-quants.c 630
1 files changed, 630 insertions, 0 deletions
diff --git a/ggml-quants.c b/ggml-quants.c
index 7d2f033e..ac061b63 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -3441,6 +3441,41 @@ static const uint64_t iq2xs_grid[512] = {
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
+static const uint32_t iq3xxs_grid[256] = { /* 256-point codebook, one uint32_t per point.
+ * Code in this file selects an entry with a one-byte index and then reads the
+ * entry as four consecutive uint8_t magnitudes (see dequantize_row_iq3_xxs).
+ * The bytes used below are 0x04, 0x0c, 0x14, 0x1c, 0x24, 0x2c, 0x34 and 0x3e.
+ * NOTE(review): the entries appear to follow the same order as kgrid_256 in
+ * iq3xs_init_impl (3-bit level l -> byte 8*l+4, with level 7 -> 0x3e) - confirm. */
+ 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+ 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+ 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+ 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+ 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+ 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+ 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+ 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+ 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+ 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+ 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+ 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+ 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+ 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+ 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+ 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+ 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+ 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+ 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+ 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+ 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+ 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+ 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+ 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+ 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+ 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+ 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+ 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+ 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+ 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+ 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+ 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+};
+
static const uint8_t ksigns_iq2xs[128] = {
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -3507,6 +3542,38 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
}
}
+// ====================== 3.0625 bpw (de)-quantization
+
+// Expands k quantized values (k must be a multiple of QK_K) from IQ3_XXS blocks x into floats y.
+// Per block, the first QK_K/4 bytes of qs are codebook indices (one per 4 values); they are
+// followed by one 32-bit word per 32 values holding four 7-bit sign selectors (low 28 bits)
+// and a 4-bit sub-block scale (top 4 bits).
+void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int n_super = k / QK_K;
+
+    for (int ibl = 0; ibl < n_super; ibl++) {
+        const float d_super = GGML_FP16_TO_FP32(x[ibl].d);
+        const uint8_t * grid_idx = x[ibl].qs;
+        const uint8_t * packed   = x[ibl].qs + QK_K/4;
+
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32, grid_idx += 8, packed += 4) {
+            // memcpy because the packed words sit at an arbitrary byte offset
+            uint32_t word;
+            memcpy(&word, packed, sizeof(uint32_t));
+            const float db = d_super * (0.5f + (word >> 28)) * 0.5f;
+
+            for (int l = 0; l < 4; ++l, y += 8) {
+                // one sign byte covers the 8 values produced in this step
+                const uint8_t signs = ksigns_iq2xs[(word >> 7*l) & 127];
+                const uint8_t * lo = (const uint8_t *)(iq3xxs_grid + grid_idx[2*l+0]);
+                const uint8_t * hi = (const uint8_t *)(iq3xxs_grid + grid_idx[2*l+1]);
+                for (int j = 0; j < 4; ++j) {
+                    const float sign_lo = (signs & kmask_iq2xs[j+0]) ? -1.f : 1.f;
+                    const float sign_hi = (signs & kmask_iq2xs[j+4]) ? -1.f : 1.f;
+                    y[j+0] = db * lo[j] * sign_lo;
+                    y[j+4] = db * hi[j] * sign_hi;
+                }
+            }
+        }
+    }
+}
+
//===================================== Q8_K ==============================================
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -8551,6 +8618,136 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
#endif
}
+/* TODO (marker left by the original author; what remains to be done is not stated).
+ *
+ * Dot product of n values (n must be a multiple of QK_K) between a row of
+ * IQ3_XXS blocks (vx) and a row of Q8_K blocks (vy); the result is stored in *s.
+ * Three implementations are selected at compile time: ARM NEON, AVX2, and a
+ * portable scalar fallback.
+ *
+ * In every path the per-32 scale is read from the top 4 bits of a packed word
+ * and the four 7-bit fields below it select sign patterns. The scalar and AVX2
+ * paths use the integer scale 2*s+1 and multiply the total by 0.25; the NEON
+ * path uses (0.5 + s) and multiplies the total by 0.5. Both equal the
+ * d * (0.5 + s) * 0.5 factor used by dequantize_row_iq3_xxs.
+ */
+void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ assert(n % QK_K == 0);
+
+ const block_iq3_xxs * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K; // number of QK_K-sized blocks in the row
+
+#if defined(__ARM_NEON)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; // each 64-bit entry = 8 sign bytes (presumably +1/-1; table not shown here)
+
+ uint32_t aux32[2]; // packed scale/sign words for two consecutive groups of 32 values
+
+ ggml_int8x16x4_t q3s;
+ ggml_int8x16x4_t q8b;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; // combined block scale of both operands
+ const uint8_t * restrict q3 = x[i].qs; // codebook indices
+ const uint8_t * restrict gas = x[i].qs + QK_K/4; // packed scale/sign words follow the indices
+ const int8_t * restrict q8 = y[i].qs;
+ float sumf1 = 0, sumf2 = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { // two groups of 32 values per iteration
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+ memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
+ const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; // 4 codebook entries = 16 magnitudes
+ const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
+ const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
+ const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
+ q3 += 16;
+ q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); // sign bytes for 16 values
+ q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
+ q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
+ q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+ q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); // signs * magnitudes -> signed quants
+ q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
+ q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
+ q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); // first group of 32
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); // second group of 32
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); // top 4 bits = sub-block scale
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
+ }
+ sumf += d*(sumf1 + sumf2);
+ }
+ *s = 0.5f * sumf; // remaining 0.5 of the d * (0.5 + s) * 0.5 scale
+
+#elif defined(__AVX2__)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; // each 64-bit entry = 8 sign bytes
+
+ uint32_t aux32[2]; // packed scale/sign words for two consecutive groups of 32 values
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; // combined block scale of both operands
+ const uint8_t * restrict q3 = x[i].qs; // codebook indices
+ const uint8_t * restrict gas = x[i].qs + QK_K/4; // packed scale/sign words follow the indices
+ const int8_t * restrict q8 = y[i].qs;
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { // two groups of 32 values per iteration
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], // set_epi32 takes the highest lane first
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+ q3 += 8;
+ const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+ q3 += 8;
+ memcpy(aux32, gas, 8); gas += 8;
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
+ signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); // signs are applied to the q8 operand; the grid bytes stay unsigned
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // unsigned grid bytes x signed q8 bytes, pairwise sums in 16 bits
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+ const uint16_t ls1 = aux32[0] >> 28; // top 4 bits = sub-block scale
+ const uint16_t ls2 = aux32[1] >> 28;
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); // scale by 2*s+1 and widen to 32 bits
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+ sumi1 = _mm256_add_epi32(sumi1, p1);
+ sumi2 = _mm256_add_epi32(sumi2, p2);
+ }
+
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = 0.25f * hsum_float_8(accumf); // 0.25 turns the integer 2*s+1 scale into (0.5 + s) * 0.5
+
+#else
+
+ uint32_t aux32; // packed scale/sign word of the current group of 32 values
+
+ float sumf = 0.f;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; // combined block scale of both operands
+ const uint8_t * restrict q3 = x[i].qs; // codebook indices
+ const uint8_t * restrict gas = x[i].qs + QK_K/4; // packed scale/sign words follow the indices
+ const int8_t * restrict q8 = y[i].qs;
+ int32_t bsum = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
+ const uint32_t ls = 2*(aux32 >> 28) + 1; // odd integer scale 2*s+1
+ int32_t sumi = 0;
+ for (int l = 0; l < 4; ++l) { // 8 values per step: two codebook entries, one sign byte
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+ for (int j = 0; j < 4; ++j) {
+ sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
+ sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ q3 += 8;
+ bsum += sumi * ls;
+ }
+ sumf += d * bsum;
+ }
+ *s = 0.25f * sumf; // 0.25 turns the integer 2*s+1 scale into (0.5 + s) * 0.5
+#endif
+}
+
// ================================ IQ2 quantization =============================================
typedef struct {
@@ -9189,3 +9386,436 @@ size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, i
return nrow * nblock * sizeof(block_iq2_xs);
}
+//
+// ============================================= 3-bit using D4 lattice
+//
+/*
+ * Lookup tables for one grid size, allocated by iq3xs_init_impl() and released
+ * by iq3xs_free_impl(). All three pointers are NULL until initialisation.
+ */
+typedef struct {
+ uint32_t * grid; // grid_size entries; each packs four int8 coordinates of the form 2*l+1
+ int * map; // indexed by a 12-bit code (4 x 3-bit levels): >= 0 is a grid index, < 0 is -(offset+1) into neighbours
+ uint16_t * neighbours; // at each offset: a count n followed by n candidate grid indices
+} iq3_entry_t;
+
+static iq3_entry_t iq3_data[1] = { // single slot: iq3_data_index() only accepts grid_size == 256
+ {NULL, NULL, NULL},
+};
+
+// Maps a grid size to its slot in iq3_data. Only the 256-point grid is supported,
+// so the answer is always slot 0.
+static inline int iq3_data_index(int grid_size) {
+    GGML_ASSERT(grid_size == 256);
+    (void)grid_size; // keeps the parameter referenced however GGML_ASSERT expands
+    return 0;
+}
+
+// qsort comparator for pairs of ints: orders by the first element, then by the second.
+static int iq3_compare_func(const void * left, const void * right) {
+    const int * a = (const int *)left;
+    const int * b = (const int *)right;
+    for (int field = 0; field < 2; ++field) {
+        if (a[field] < b[field]) {
+            return -1;
+        }
+        if (a[field] > b[field]) {
+            return 1;
+        }
+    }
+    return 0;
+}
+
+// Decodes the 12-bit code `point` (four 3-bit levels l, coordinate = 2*l+1) and fills dist2
+// with (squared distance to point, grid index) pairs for every grid entry, sorted ascending
+// by distance and then by index. dist2 must have room for 2*grid_size ints.
+static void iq3_rank_grid_by_distance(int point, const uint32_t * grid, int grid_size, int * dist2) {
+    int8_t pos[4];
+    for (int k = 0; k < 4; ++k) {
+        int l = (point >> 3*k) & 0x7;
+        pos[k] = 2*l + 1;
+    }
+    for (int j = 0; j < grid_size; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + j);
+        int d2 = 0;
+        for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+        dist2[2*j+0] = d2;
+        dist2[2*j+1] = j;
+    }
+    qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
+}
+
+// Given dist2 as produced by iq3_rank_grid_by_distance, returns how many leading entries
+// belong to the nwant smallest distinct distances.
+static int iq3_count_nearest(const int * dist2, int grid_size, int nwant) {
+    int n = 0;
+    int d2 = dist2[0];
+    int nhave = 1;
+    for (int j = 0; j < grid_size; ++j) {
+        if (dist2[2*j] > d2) {
+            if (nhave == nwant) break;
+            d2 = dist2[2*j];
+            ++nhave;
+        }
+        ++n;
+    }
+    return n;
+}
+
+// Builds the lookup tables stored in iq3_data for the given grid size (only 256 is accepted):
+//   grid       - the grid points, four int8 coordinates 2*l+1 packed per uint32_t
+//   map        - for each of the 4096 possible 12-bit codes, either the index of the matching
+//                grid point (>= 0) or -(offset+1), where offset locates the code's list in neighbours
+//   neighbours - per off-grid code: a count followed by the indices of the grid points lying at
+//                the two smallest distinct distances
+// Does nothing if the tables already exist. Allocation failures trip GGML_ASSERT instead of
+// being dereferenced.
+void iq3xs_init_impl(int grid_size) {
+    const int gindex = iq3_data_index(grid_size);
+    if (iq3_data[gindex].grid) {
+        return;
+    }
+    static const uint16_t kgrid_256[256] = {
+            0,     2,     4,     9,    11,    15,    16,    18,    25,    34,    59,    61,    65,    67,    72,    74,
+           81,    85,    88,    90,    97,   108,   120,   128,   130,   132,   137,   144,   146,   153,   155,   159,
+          169,   175,   189,   193,   199,   200,   202,   213,   248,   267,   287,   292,   303,   315,   317,   321,
+          327,   346,   362,   413,   436,   456,   460,   462,   483,   497,   513,   515,   520,   522,   529,   531,
+          536,   538,   540,   551,   552,   576,   578,   585,   592,   594,   641,   643,   648,   650,   657,   664,
+          698,   704,   706,   720,   729,   742,   758,   769,   773,   808,   848,   852,   870,   889,   901,   978,
+          992,  1024,  1026,  1033,  1035,  1040,  1042,  1046,  1049,  1058,  1089,  1091,  1093,  1096,  1098,  1105,
+         1112,  1139,  1143,  1144,  1152,  1154,  1161,  1167,  1168,  1170,  1183,  1184,  1197,  1217,  1224,  1228,
+         1272,  1276,  1309,  1323,  1347,  1367,  1377,  1404,  1473,  1475,  1486,  1509,  1537,  1544,  1546,  1553,
+         1555,  1576,  1589,  1594,  1600,  1602,  1616,  1625,  1636,  1638,  1665,  1667,  1672,  1685,  1706,  1722,
+         1737,  1755,  1816,  1831,  1850,  1856,  1862,  1874,  1901,  1932,  1950,  1971,  2011,  2032,  2052,  2063,
+         2077,  2079,  2091,  2095,  2172,  2192,  2207,  2208,  2224,  2230,  2247,  2277,  2308,  2345,  2356,  2389,
+         2403,  2424,  2501,  2504,  2506,  2520,  2570,  2593,  2616,  2624,  2630,  2646,  2669,  2700,  2714,  2746,
+         2754,  2795,  2824,  2835,  2839,  2874,  2882,  2905,  2984,  3028,  3042,  3092,  3108,  3110,  3124,  3153,
+         3185,  3215,  3252,  3288,  3294,  3364,  3397,  3434,  3483,  3523,  3537,  3587,  3589,  3591,  3592,  3610,
+         3626,  3670,  3680,  3722,  3749,  3754,  3776,  3789,  3803,  3824,  3857,  3873,  3904,  3906,  3924,  3992,
+    };
+    const int kmap_size = 4096;
+    const int nwant = 2;
+    const uint16_t * kgrid = kgrid_256;
+    uint32_t * kgrid_q3xs;
+    int      * kmap_q3xs;
+    uint16_t * kneighbors_q3xs;
+
+    printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+
+    // Grid: expand each 12-bit code into four odd coordinates.
+    uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
+    GGML_ASSERT(the_grid && "out of memory");
+    for (int k = 0; k < grid_size; ++k) {
+        int8_t * pos = (int8_t *)(the_grid + k);
+        for (int i = 0; i < 4; ++i) {
+            int l = (kgrid[k] >> 3*i) & 0x7;
+            pos[i] = 2*l + 1;
+        }
+    }
+    kgrid_q3xs = the_grid;
+    iq3_data[gindex].grid = the_grid;
+
+    // Map: -1 everywhere, then the grid index for every code that is on the grid.
+    kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
+    GGML_ASSERT(kmap_q3xs && "out of memory");
+    iq3_data[gindex].map = kmap_q3xs;
+    for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
+    uint32_t aux32;
+    uint8_t * aux8 = (uint8_t *)&aux32;
+    for (int i = 0; i < grid_size; ++i) {
+        aux32 = kgrid_q3xs[i];
+        uint16_t index = 0;
+        for (int k=0; k<4; ++k) {
+            uint16_t q = (aux8[k] - 1)/2;
+            index |= (q << 3*k);
+        }
+        kmap_q3xs[index] = i;
+    }
+
+    // First pass: count how many neighbour entries the off-grid codes need in total.
+    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+    GGML_ASSERT(dist2 && "out of memory");
+    int num_neighbors = 0, num_not_in_map = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q3xs[i] >= 0) continue;
+        ++num_not_in_map;
+        iq3_rank_grid_by_distance(i, kgrid_q3xs, grid_size, dist2);
+        num_neighbors += iq3_count_nearest(dist2, grid_size, nwant);
+    }
+    printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+
+    // Second pass: one count slot per off-grid code plus its neighbour indices.
+    kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+    GGML_ASSERT(kneighbors_q3xs && "out of memory");
+    iq3_data[gindex].neighbours = kneighbors_q3xs;
+    int counter = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q3xs[i] >= 0) continue;
+        iq3_rank_grid_by_distance(i, kgrid_q3xs, grid_size, dist2);
+        const int n = iq3_count_nearest(dist2, grid_size, nwant);
+        kmap_q3xs[i] = -(counter + 1); // negative value encodes the offset of this code's list
+        kneighbors_q3xs[counter++] = (uint16_t)n;
+        for (int j = 0; j < n; ++j) {
+            kneighbors_q3xs[counter++] = dist2[2*j+1];
+        }
+    }
+    free(dist2);
+}
+
+// Releases the tables created by iq3xs_init_impl for the given grid size and resets the
+// slot to NULL so a later init rebuilds them. Does nothing if the slot was never initialised.
+void iq3xs_free_impl(int grid_size) {
+    GGML_ASSERT(grid_size == 256);
+    iq3_entry_t * entry = &iq3_data[iq3_data_index(grid_size)];
+    if (!entry->grid) {
+        return;
+    }
+    free(entry->grid);
+    free(entry->map);
+    free(entry->neighbours);
+    entry->grid       = NULL;
+    entry->map        = NULL;
+    entry->neighbours = NULL;
+}
+
+// Picks, among the candidate grid points listed in neighbours (neighbours[0] = count, followed
+// by that many grid indices), the one minimising sum_i weight[i]*(scale*q[i] - xval[i])^2.
+// Writes the winner's four levels (q-1)/2 to L and returns its grid index.
+static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+    const int count = neighbours[0];
+    GGML_ASSERT(count > 0);
+
+    const uint16_t * candidates = neighbours + 1;
+    int   best_index = -1;
+    float best_err   = FLT_MAX;
+    for (int c = 0; c < count; ++c) {
+        const int8_t * point = (const int8_t *)(grid + candidates[c]);
+        float err = 0;
+        for (int i = 0; i < 4; ++i) {
+            const float q = point[i];
+            const float diff = scale*q - xval[i];
+            err += weight[i]*diff*diff;
+        }
+        if (err < best_err) {
+            best_err   = err;
+            best_index = candidates[c];
+        }
+    }
+    GGML_ASSERT(best_index >= 0);
+
+    const int8_t * chosen = (const int8_t *)(grid + best_index);
+    for (int i = 0; i < 4; ++i) {
+        L[i] = (chosen[i] - 1)/2;
+    }
+    return best_index;
+}
+
+// Quantizes n floats (n must be a multiple of QK_K) from x into IQ3_XXS blocks at vy.
+//
+//   x             - input values
+//   vy            - output, an array of n/QK_K block_iq3_xxs
+//   n             - number of input values
+//   quant_weights - optional per-value importance weights (may be NULL; x*x is used instead)
+//
+// Requires the tables built by iq3xs_init_impl(256). For every group of 32 values the signs
+// are split off (forcing an even number of flips per 8 values, so 7 bits encode them), a scale
+// search picks levels for each group of 4 magnitudes, and groups that fall off the grid are
+// snapped to their best listed neighbour. Each block stores QK_K/4 grid indices followed by
+// QK_K/32 words holding 4 x 7 sign bits and a 4-bit scale.
+static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+    const int gindex = iq3_data_index(256);
+
+    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;
+    const int      * kmap_q3xs       = iq3_data[gindex].map;
+    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
+
+    GGML_ASSERT(kgrid_q3xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q3xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 8;
+
+    // Blocks hold QK_K values each; derive the count from QK_K so it agrees with the assert
+    // above and with the strides used below.
+    const int nbl = n/QK_K;
+
+    block_iq3_xxs * y = vy;
+
+    float scales[QK_K/32];
+    float weight[32];
+    float xval[32];
+    int8_t L[32];
+    int8_t Laux[32];
+    float waux[32];
+    bool is_on_grid[8];
+    bool is_on_grid_aux[8];
+    uint8_t block_signs[8];
+    uint8_t q3[QK_K/4];
+    // Kept as a real uint32_t array (not a cast view into a byte buffer) so that the accesses
+    // are aligned and do not violate effective-type rules; copied out with memcpy at the end.
+    uint32_t scales_and_signs[QK_K/32];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(q3, 0, sizeof(q3));
+        memset(scales_and_signs, 0, sizeof(scales_and_signs));
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float * xb = xbl + 32*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + 32*ib;
+                for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
+            }
+            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+
+            // Split off the signs, 8 values at a time. Only 7 sign bits are stored, so the
+            // number of flips must be even: if it is odd, flip the least important value too.
+            for (int k = 0; k < 4; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+
+            float max = xval[0];
+            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                // All-zero group: q3 and scales_and_signs for it stay at their zeroed defaults.
+                scales[ib] = 0;
+                continue;
+            }
+
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            // Start from a well-defined state: if no candidate scale below is accepted (e.g. all
+            // weights are zero), L and is_on_grid would otherwise be read uninitialised or stale.
+            // All-zero levels map to code 0, which is a grid point.
+            memset(L, 0, sizeof(L));
+            for (int k = 0; k < 8; ++k) is_on_grid[k] = true;
+
+            for (int is = -15; is <= 15; ++is) {
+                float id = (2*kMaxQ-1+is*0.2f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 8; ++k) {
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
+                    is_on_grid_aux[k] = true;
+                    if (kmap_q3xs[u] < 0) {
+                        // Off the grid: snap to the best listed neighbour (updates Laux in place).
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
+                    for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+
+            // Re-quantize the groups that had to be snapped, now using the refined scale.
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 8; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 3*i);
+                    }
+                    int grid_index = kmap_q3xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
+                    }
+                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
+                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+                // and correspondingly flip quant signs.
+                scale = -scale;
+                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+
+            // Every group of 4 levels must now be a grid point; store its index.
+            for (int k = 0; k < 8; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
+                int grid_index = kmap_q3xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
+                    printf("\n");
+                    GGML_ASSERT(false);
+                }
+                q3[8*ib+k] = grid_index;
+            }
+            scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, 3*QK_K/8);
+            continue;
+        }
+
+        // Block scale d is chosen so that the largest sub-block scale maps to the top 4-bit
+        // level (2*15+1 = 31); each sub-block scale goes into the top 4 bits of its word.
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            scales_and_signs[ib] |= ((uint32_t)l << 28);
+        }
+
+        // Output layout: QK_K/4 grid indices followed by QK_K/32 packed scale/sign words.
+        memcpy(y[ibl].qs, q3, sizeof(q3));
+        memcpy(y[ibl].qs + QK_K/4, scales_and_signs, sizeof(scales_and_signs));
+    }
+}
+
+// Quantizes nrow rows of n_per_row floats from src into IQ3_XXS blocks at dst and returns the
+// number of bytes written. hist is ignored. The same quant_weights pointer is passed for every
+// row.
+size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    char * out = (char *)dst;
+    for (int row = 0; row < nrow; ++row, src += n_per_row, out += nblock*sizeof(block_iq3_xxs)) {
+        quantize_row_iq3_xxs_impl(src, out, n_per_row, quant_weights);
+    }
+    return nrow * nblock * sizeof(block_iq3_xxs);
+}
+
+// Untyped-output entry point: quantizes k floats (k a multiple of QK_K) from x into the
+// IQ3_XXS blocks at vy by forwarding to the reference implementation.
+void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_iq3_xxs * restrict blocks = vy;
+    quantize_row_iq3_xxs_reference(x, blocks, k);
+}
+
+// Reference quantizer: quantizes k floats (k a multiple of QK_K) from x into y without
+// importance weights (quant_weights = NULL).
+void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const float * no_quant_weights = NULL;
+    quantize_row_iq3_xxs_impl(x, y, k, no_quant_weights);
+}