summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-quants.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-quants.c')
-rw-r--r--ggml/src/ggml-quants.c139
1 files changed, 139 insertions, 0 deletions
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index bef2f73e..f5fff22e 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -848,6 +848,59 @@ void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q5_1_ref(x, y, k);
}
+void quantize_row_q6_0_ref(const float * restrict x, block_q6_0 * restrict y, int64_t k) {
+ static const int qk = QK6_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -32;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ //y[i].d = GGML_FP32_TO_FP16(d);
+ memset(y[i].qh, 0, qk/4);
+
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = x[i*qk + 0 + j]*id;
+ const float x1 = x[i*qk + qk/2 + j]*id;
+ const float w0 = x0*x0;
+ const float w1 = x1*x1;
+
+ const uint8_t xi0 = MIN(63, (int8_t)(x0 + 32.5f));
+ const uint8_t xi1 = MIN(63, (int8_t)(x1 + 32.5f));
+
+ y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+ const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
+ y[i].qh[j%(qk/4)] |= (h << 4*(j/(qk/4)));
+
+ const float q0 = (float)xi0 - 32.f;
+ const float q1 = (float)xi1 - 32.f;
+ sumqx += w0*x[i*qk + j]*q0 + w1*x[i*qk + qk/2 + j]*q1;
+ sumq2 += w0*q0*q0 + w1*q1*q1;
+ }
+ y[i].d = sumq2 > 0 ? GGML_FP32_TO_FP16(sumqx/sumq2) : GGML_FP32_TO_FP16(d);
+ }
+}
+
+void quantize_row_q6_0(const float * restrict x, void * restrict y, int64_t k) {
+ quantize_row_q6_0_ref(x, y, k);
+}
+
// reference implementation for deterministic creation of model files
void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
assert(k % QK8_0 == 0);
@@ -1691,6 +1744,28 @@ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int6
}
}
+void dequantize_row_q6_0(const block_q6_0 * restrict x, float * restrict y, int64_t k) {
+ static const int qk = QK6_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t h = x[i].qh[j%(qk/4)] >> 4*(j/(qk/4));
+
+ const int32_t x0 = ((x[i].qs[j] & 0x0F) | ((h << 4) & 0x30)) - 32;
+ const int32_t x1 = ((x[i].qs[j] >> 4) | ((h << 2) & 0x30)) - 32;
+
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
+ }
+ }
+}
+
void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) {
static const int qk = QK8_0;
@@ -3429,6 +3504,54 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr
return nrow * row_size;
}
+static void quantize_row_q6_0_impl(const float * restrict x, block_q6_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
+ static_assert(QK6_0 == 32, "QK6_0 must be 32");
+
+ float weight[QK6_0];
+ int8_t L[QK6_0];
+
+ float sigma2 = 0;
+ if (quant_weights) {
+ float sum_x2 = 0;
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+ sigma2 = sum_x2/n_per_row;
+ }
+
+ const int64_t nb = n_per_row/QK6_0;
+ for (int ib = 0; ib < nb; ++ib) {
+ const float * xb = x + QK6_0 * ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + QK6_0 * ib;
+ for (int j = 0; j < QK6_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < QK6_0; ++j) weight[j] = xb[j]*xb[j];
+ }
+ float d = make_qx_quants(QK6_0, 32, xb, L, 1, weight);
+ y[ib].d = GGML_FP32_TO_FP16(d);
+
+ memset(y[ib].qh, 0, QK6_0/4);
+
+ for (int j = 0; j < 16; ++j) {
+ const uint8_t xi0 = L[j];
+ const uint8_t xi1 = L[j+16];
+ y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+ const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
+ y[ib].qh[j%8] |= (h << 4*(j/8));
+ }
+ }
+}
+
+size_t quantize_q6_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ size_t row_size = ggml_row_size(GGML_TYPE_Q6_0, n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q6_0_impl(src, (block_q6_0*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrow * row_size;
+}
+
size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights; // not used
const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
@@ -5383,6 +5506,21 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
*s = sumf;
}
+void ggml_vec_dot_q6_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+#ifdef __AVX2__
+ const enum ggml_type vec_dot_type = GGML_TYPE_Q8_1;
+#else
+ const enum ggml_type vec_dot_type = GGML_TYPE_Q8_0;
+#endif
+ if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q6_0, vx, bx, vec_dot_type, vy, by, s, bs, 0, 1)) {
+ return;
+ }
+#endif
+ // TODO
+ *s = 0;
+}
+
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
@@ -15020,6 +15158,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break;
+ case GGML_TYPE_Q6_0: break;
case GGML_TYPE_IQ2_K: break;
case GGML_TYPE_IQ3_K: break;
case GGML_TYPE_IQ4_K: break;