summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml/include/ggml.h6
-rw-r--r--ggml/src/ggml-common.h12
-rw-r--r--ggml/src/ggml-cuda/common.cuh7
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cu18
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cuh4
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu3
-rw-r--r--ggml/src/ggml.c21
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp347
-rw-r--r--ggml/src/iqk/iqk_quantize.h6
9 files changed, 414 insertions, 10 deletions
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 144e87f5..a0bcc67f 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -393,7 +393,8 @@ extern "C" {
GGML_TYPE_IQ3_K = 38,
GGML_TYPE_IQ4_K = 39,
GGML_TYPE_IQ5_K = 40,
- GGML_TYPE_IQ2_TN = 41,
+ GGML_TYPE_IQ6_K = 41,
+ GGML_TYPE_IQ2_TN = 42,
GGML_TYPE_COUNT,
};
@@ -444,7 +445,8 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors
- GGML_FTYPE_MOSTLY_IQ2_TN = 34, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ6_K = 34, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_TN = 35, // except 1d tensors
};
// available tensor operations:
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 5847d903..2fbac06a 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -142,6 +142,9 @@ typedef sycl::half2 ggml_half2;
#define QI5_XS (QK_K / (4*QR5_XS))
#define QR5_XS 2
+#define QI6_XS (QK_K / (4*QR6_XS))
+#define QR6_XS 2
+
#define QI3_S (QK_K / (4*QR3_S))
#define QR3_S 4
@@ -493,6 +496,15 @@ typedef struct {
} block_iq5_k;
static_assert(sizeof(block_iq5_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/8 + 3*QK_K/64, "wrong iq5_k block size/padding");
+typedef struct {
+ ggml_half d;
+ uint16_t extra;
+ int8_t scales[QK_K/16];
+ uint8_t qs[QK_K/2];
+ uint8_t qh[QK_K/4];
+} block_iq6_k;
+static_assert(sizeof(block_iq6_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/4 + QK_K/16, "wrong iq6_k block size/padding");
+
#endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index c18e865a..07a53bcd 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -705,6 +705,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ6_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR6_XS;
+ static constexpr int qi = QI6_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_S;
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index 8def1547..ae5e6a3c 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -251,6 +251,16 @@ __device__ __forceinline__ float vec_dot_iq5_k_q8_1(
return d5 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2);
}
+#define VDR_IQ6_K_Q8_1_MMVQ 4
+#define VDR_IQ6_K_Q8_1_MMQ 4
+
+// TODO
+__device__ __forceinline__ float vec_dot_iq6_k_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ return 0;
+}
+
static const __device__ uint32_t iq2k_table[512] = {
0xe1e1e1e1, 0xe1e1e1f3, 0xe1e1e101, 0xe1e1e111, 0xe1e1f3e1, 0xe1e1f3f3, 0xe1e1f301, 0xe1e1f311,
0xe1e101e1, 0xe1e101f3, 0xe1e10101, 0xe1e10111, 0xe1e111e1, 0xe1e111f3, 0xe1e11101, 0xe1e11111,
@@ -534,10 +544,16 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
+void mul_mat_vec_iq6_k_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
void mul_mat_vec_iq2_tn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
-
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 3dc5f41c..7af8e570 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -16,6 +16,10 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+void mul_mat_vec_iq6_k_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+
void mul_mat_vec_iq2_tn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 428d822f..9eb3fa4f 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -447,6 +447,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ5_K:
mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
+ case GGML_TYPE_IQ6_K:
+ mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 5c817030..b5fdb96d 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1040,6 +1040,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
+ [GGML_TYPE_IQ6_K] = {
+ .type_name = "iq6_k",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq6_k),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq6_k,
+ .from_float = quantize_row_iq6_k,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq6_k_ref,
+ .vec_dot = vec_dot_iq6_k_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
};
// For internal test use
@@ -3394,6 +3406,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break;
case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break;
case GGML_FTYPE_MOSTLY_IQ5_K: wtype = GGML_TYPE_IQ5_K; break;
+ case GGML_FTYPE_MOSTLY_IQ6_K: wtype = GGML_TYPE_IQ6_K; break;
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break;
@@ -9648,6 +9661,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -10033,6 +10047,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -10168,6 +10183,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -13092,6 +13108,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -13287,6 +13304,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -13556,6 +13574,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
@@ -14152,6 +14171,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q8_K:
@@ -20892,6 +20912,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ5_K: result = quantize_iq5_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ6_K: result = quantize_iq6_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 1cba1532..f12367a4 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -1516,6 +1516,346 @@ size_t quantize_iq5_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
}
//
+// ============================================== iq6_K
+//
+void dequantize_row_iq6_k(const block_iq6_k * x, float * y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const uint8_t * qs = x[i].qs;
+ const uint8_t * qh = x[i].qh;
+ const int8_t * sl = x[i].scales;
+
+ uint16_t extra = x[i].extra;
+
+ int shift = 0;
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+
+ float dl1 = d * sl[4*ib64 + 0];
+ float dl2 = d * sl[4*ib64 + 1];
+ float dl3 = d * sl[4*ib64 + 2];
+ float dl4 = d * sl[4*ib64 + 3];
+ int m1 = extra & 1 ? -127 : -125;
+ int m2 = extra & 2 ? -127 : -125;
+ int m3 = extra & 4 ? -127 : -125;
+ int m4 = extra & 8 ? -127 : -125;
+ for (int j = 0; j < 16; ++j) {
+ y[j+ 0] = dl1 * ((((qs[j+ 0] & 0xf) | (((qh[j+ 0] >> shift) & 0x03) << 4)) << 2) + m1);
+ y[j+16] = dl2 * ((((qs[j+16] & 0xf) | (((qh[j+16] >> shift) & 0x03) << 4)) << 2) + m2);
+ y[j+32] = dl3 * ((((qs[j+ 0] >> 4) | (((qh[j+ 0] >> shift) & 0x0c) << 2)) << 2) + m3);
+ y[j+48] = dl4 * ((((qs[j+16] >> 4) | (((qh[j+16] >> shift) & 0x0c) << 2)) << 2) + m4);
+ }
+ y += 64;
+ qs += 32;
+ extra >>= 4;
+ shift += 4;
+ if (shift == 8) { qh += 32; shift = 0; }
+ }
+
+ }
+}
+
+void vec_dot_iq6_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ6_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+
+ // TODO
+ //const int nb = n / QK_K;
+
+ //const block_iq5_k * x = (const block_iq5_k *)vx;
+ //const block_q8_K * y = (const block_q8_K *)vy;
+
+ //float sumf = 0;
+
+ //for (int i = 0; i < nb; i++) {
+
+ // const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ // const uint8_t * qs = x[i].qs;
+ // const uint8_t * qh = x[i].qh;
+ // const uint8_t * sl = x[i].scales_l;
+ // const uint8_t * sh = x[i].scales_h;
+ // const int8_t * q8 = y[i].qs;
+
+ // uint16_t extra = x[i].extra;
+
+ // int shift = 0;
+ // int sumb = 0;
+ // for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+
+ // int dl1 = (((sl[2*ib64+0] & 0xf) | ((sh[ib64] << 4) & 0x30)) - 32);
+ // int dl2 = (((sl[2*ib64+0] >> 4) | ((sh[ib64] << 2) & 0x30)) - 32);
+ // int dl3 = (((sl[2*ib64+1] & 0xf) | ((sh[ib64] >> 0) & 0x30)) - 32);
+ // int dl4 = (((sl[2*ib64+1] >> 4) | ((sh[ib64] >> 2) & 0x30)) - 32);
+ // const int8_t * values1 = iq5nl_values + ((extra & 1) << 5);
+ // const int8_t * values2 = iq5nl_values + ((extra & 2) << 4);
+ // const int8_t * values3 = iq5nl_values + ((extra & 4) << 3);
+ // const int8_t * values4 = iq5nl_values + ((extra & 8) << 2);
+ // int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+ // for (int j = 0; j < 16; ++j) {
+ // sumi1 += q8[j+ 0] * values1[(qs[j+ 0] & 0xf) | (((qh[j+ 0] >> shift) & 1) << 4)];
+ // sumi2 += q8[j+16] * values2[(qs[j+16] & 0xf) | (((qh[j+16] >> shift) & 1) << 4)];
+ // sumi3 += q8[j+32] * values3[(qs[j+ 0] >> 4) | (((qh[j+ 0] >> shift) & 2) << 3)];
+ // sumi4 += q8[j+48] * values4[(qs[j+16] >> 4) | (((qh[j+16] >> shift) & 2) << 3)];
+ // }
+ // sumb += dl1 * sumi1 + dl2 * sumi2 + dl3 * sumi3 + dl4 * sumi4;
+ // q8 += 64;
+ // qs += 32;
+ // extra >>= 4;
+ // shift += 2;
+ // }
+ // sumf += d * sumb;
+
+ //}
+
+ //*s = sumf;
+
+}
+
+namespace {
+
+void quantize_row_iq6_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) {
+ const int ntry = 5;
+ const float step = 1.f;
+
+ block_iq6_k * y = (block_iq6_k *)vy;
+
+ float scales[QK_K/16];
+ float weight[16];
+
+ uint8_t L[QK_K];
+
+ //int nerr = 0;
+
+ for (int ibl = 0; ibl < n_per_row/QK_K; ++ibl) {
+
+ memset(&y[ibl], 0, sizeof(block_iq6_k));
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
+
+ const float * xbl = x + ibl*QK_K;
+ float sumx2 = 0;
+ for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j];
+ const float sigma2 = 2*sumx2/QK_K;
+
+ float max_abs_scale = 0;
+ uint16_t extra = 0;
+
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ const float * xb = xbl + 16*ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + ibl*QK_K + ib*16;
+ for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < 16; ++j) weight[j] = 1.f; //0.25f*sigma2 + xb[j]*xb[j];
+ }
+ float amax = 0;
+ for (int j = 0; j < 16; ++j) {
+ float ax = fabsf(xb[j]);
+ amax = std::max(ax, amax);
+ }
+ if (!amax) {
+ scales[ib] = 0;
+ continue;
+ }
+ float d = amax/127;
+ float id = 0.25f/d;
+ float sumqx_p = 0, sumq2_p = 0;
+ float sumqx_m = 0, sumq2_m = 0;
+ for (int j = 0; j < 16; ++j) {
+ float w = weight[j];
+ int lp = nearest_int(id*xb[j] + 31.25f);
+ lp = std::max(0, std::min(63, lp));
+ float qp = 4*lp - 125;
+ sumqx_p += w*qp*xb[j];
+ sumq2_p += w*qp*qp;
+ int lm = nearest_int(id*xb[j] + 31.75f);
+ lm = std::max(0, std::min(63, lm));
+ float qm = 4*lm - 127;
+ sumqx_m += w*qm*xb[j];
+ sumq2_m += w*qm*qm;
+ //printf("x = %g, lp = %d, qp = %g -> %g, lm = %d, qm = %g -> %g\n", xb[j], lp, qp, d*qp, lm, qm, d*qm);
+ }
+ d = sumqx_p/sumq2_p;
+ float best = d*sumqx_p;
+ bool is_shifted = false;
+ if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
+ d = sumqx_m/sumq2_m; best = d*sumqx_m; is_shifted = true;
+ }
+ for (int itry = -ntry; itry <= ntry; ++itry) {
+ //0.25/amax*127 => 31.75/amax
+ id = (itry*step + 31.75f)/amax;
+ sumqx_p = sumq2_p = 0;
+ sumqx_m = sumq2_m = 0;
+ for (int j = 0; j < 16; ++j) {
+ float w = weight[j];
+ int l = nearest_int(id*xb[j] + 31.25f);
+ l = std::max(0, std::min(63, l));
+ float q = 4*l - 125;
+ sumqx_p += w*q*xb[j];
+ sumq2_p += w*q*q;
+ l = nearest_int(id*xb[j] + 31.75f);
+ l = std::max(0, std::min(63, l));
+ q = 4*l - 127;
+ sumqx_m += w*q*xb[j];
+ sumq2_m += w*q*q;
+ }
+ if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
+ d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false;
+ }
+ if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
+ d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true;
+ }
+ }
+ scales[ib] = d;
+ if (is_shifted) extra |= (1 << ib);
+
+ max_abs_scale = std::max(max_abs_scale, amax);
+
+ //float mse = 0;
+ //id = 0.25f/d;
+ //float xmin = is_shifted ? 31.75f : 31.25f;
+ //for (int j = 0; j < 16; ++j) {
+ // int l = nearest_int(id*xb[j] + xmin);
+ // l = std::max(0, std::min(63, l));
+ // float diff = xb[j] - 4*d*(l - xmin);
+ // mse += diff*diff;
+ //}
+ //printf("Block %d: %g\n", ib, sqrtf(mse/16));
+
+ }
+
+ if (!max_abs_scale) continue;
+ float d = max_abs_scale/255;
+ y[ibl].d = GGML_FP32_TO_FP16(d);
+ y[ibl].extra = extra;
+
+ float id = 1/d;
+
+ std::memset(L, 0, QK_K);
+ float sumqx = 0, sumq2 = 0;
+ //float tot_mse = 0;
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ int ls = nearest_int(id*scales[ib]);
+ ls = MAX(0, MIN(255, ls));
+ y[ibl].scales[ib] = ls;
+ float dl = d * ls;
+ if (dl) {
+ const float xmin = y[ibl].extra & (1 << ib) ? 31.75f : 31.25f;
+ const float * xb = xbl + 16*ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + ibl*QK_K + ib*16;
+ for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < 16; ++j) weight[j] = 1.f; //0.25f*sigma2 + xb[j]*xb[j];
+ }
+ float idl = 0.25f/dl;
+ int ib32 = ib/2;
+ int offset = 16*(ib%2);
+ uint8_t * qs = y[ibl].qs + 32*(ib32/2) + offset;
+ uint8_t * qh = y[ibl].qh + 32*(ib32/4) + offset;
+ //float mse1 = 0, mse2 = 0;
+ for (int j = 0; j < 16; ++j) {
+ int l = nearest_int(idl*xb[j] + xmin);
+ l = std::max(0, std::min(63, l));
+ L[16*ib + j] = l;
+ qs[j] |= ((l & 0xf) << 4*(ib32%2));
+ qh[j] |= ((l >> 4) << 2*(ib32%4));
+ float w = weight[j];
+ float q = 4*(l - xmin)*ls;
+ sumqx += w*q*xb[j];
+ sumq2 += w*q*q;
+ //float diff = xb[j] - 4*d*ls*(l - xmin);
+ //mse1 += diff*diff;
+ int ll = ((qs[j] >> 4*(ib32%2)) & 0xf) | (((qh[j] >> 2*(ib32%4)) << 4) & 0x30);
+ if (ll != l) {
+ printf("Oops: l = %d, ll = %d, qs = %u, qh = %u, ib = %d\n", l, ll, qs[j], qh[j], ib);
+ exit(1);
+ }
+ //diff = xb[j] - 4*d*ls*(ll - xmin);
+ //mse2 += diff*diff;
+ }
+ //printf("Block %d: %g, %g\n", ib, sqrtf(mse1/16), sqrtf(mse2/16));
+ //tot_mse += mse1;
+ }
+ }
+ //printf("=============== rmse = %g, Old scale: %g New scale: %g\n", sqrtf(tot_mse/256), d, sumqx/sumq2);
+ if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(sumqx/sumq2);
+
+ //d = GGML_FP16_TO_FP32(y[ibl].d);
+ //tot_mse = 0;
+ //for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ // const float * xb = xbl + 32*ib32;
+ // float dl1 = d * y[ibl].scales[2*ib32+0];
+ // float dl2 = d * y[ibl].scales[2*ib32+1];
+ // int min1 = y[ibl].extra & (1 << (2*ib32+0)) ? -127 : -125;
+ // int min2 = y[ibl].extra & (1 << (2*ib32+1)) ? -127 : -125;
+ // const uint8_t * qs = y[ibl].qs + 32*(ib32/2);
+ // const uint8_t * qh = y[ibl].qh + 32*(ib32/4);
+ // for (int j = 0; j < 16; ++j) {
+ // int l = ((qs[j] >> 4*(ib32%2)) & 0xf) | (((qh[j] >> 2*(ib32%4)) << 4) & 0x30);
+ // if (l != L[32*ib32 + j]) {
+ // ++nerr;
+ // printf("Oops: %d vs %u for ib32 = %d, j = %d. qs = %u (0x%02x), qh = %u (0x%02x)\n", l, L[32*ib32 + j], ib32, j, qs[j], qs[j], qh[j], qh[j]);
+ // if (nerr > 10) exit(1);
+ // }
+ // float diff = dl1*(4*l + min1) - xb[j];
+ // tot_mse += diff*diff;
+ // //printf(" %d %d %g\n", l, 4*l + min1, diff);
+ // }
+ // for (int j = 16; j < 32; ++j) {
+ // int l = ((qs[j] >> 4*(ib32%2)) & 0xf) | (((qh[j] >> 2*(ib32%4)) << 4) & 0x30);
+ // if (l != L[32*ib32 + j]) {
+ // ++nerr;
+ // printf("Oops: %d vs %u for ib32 = %d, j = %d. qs = %u (0x%02x), qh = %u (0x%02x)\n", l, L[32*ib32 + j], ib32, j, qs[j], qs[j], qh[j], qh[j]);
+ // if (nerr > 10) exit(1);
+ // }
+ // float diff = dl2*(4*l + min2) - xb[j];
+ // tot_mse += diff*diff;
+ // //printf(" %d %d %g\n", l, 4*l + min2, diff);
+ // }
+ //}
+ //printf(" after adjusting scale: d = %g, rmse = %g\n", d, sqrtf(tot_mse/256));
+
+ }
+
+}
+
+}
+
+void quantize_row_iq6_k_ref(const float * x, block_iq6_k * y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq6_k(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq6_k(const float * x, void * vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq6_k * y = (block_iq6_k *)vy;
+ quantize_row_iq6_k_ref(x, y, k);
+}
+
+size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrows; ++row) {
+ quantize_row_iq6_k_impl(src, (void *)qrow, n_per_row, imatrix);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq6_k);
+ }
+ return nrows * nblock * sizeof(block_iq6_k);
+}
+
+//
// ========================== IQ2_TN
//
@@ -1582,13 +1922,6 @@ void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) {
}
void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
- assert(n % QK_K == 0);
- assert(nrc == 1);
- GGML_UNUSED(nrc);
- GGML_UNUSED(bx);
- GGML_UNUSED(by);
- GGML_UNUSED(bs);
-
if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_TN, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
return;
}
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 80a9012b..3c5d27a4 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -37,6 +37,12 @@ size_t quantize_iq5_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
void dequantize_row_iq5_k(const block_iq5_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq5_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_iq6_k_ref(const float * GGML_RESTRICT x, block_iq6_k * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq6_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq6_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq6_k(const block_iq6_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq6_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void quantize_row_iq2_tn_ref(const float * GGML_RESTRICT x, block_iq2_tn * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_tn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);