-rw-r--r--   common/common.cpp                      |   3
-rw-r--r--   examples/llama-bench/llama-bench.cpp   |   3
-rw-r--r--   examples/quantize/quantize.cpp         |   1
-rw-r--r--   ggml/include/ggml.h                    |   2
-rw-r--r--   ggml/src/ggml-common.h                 |  11
-rw-r--r--   ggml/src/ggml-cuda.cu                  |   4
-rw-r--r--   ggml/src/ggml-cuda/common.cuh          |   7
-rw-r--r--   ggml/src/ggml-cuda/convert.cu          |  42
-rw-r--r--   ggml/src/ggml-cuda/cpy.cu              |  50
-rw-r--r--   ggml/src/ggml-cuda/mmvq.cu             |  12
-rw-r--r--   ggml/src/ggml-cuda/vecdotq.cuh         |  44
-rw-r--r--   ggml/src/ggml-metal.m                  |  33
-rw-r--r--   ggml/src/ggml-metal.metal              | 139
-rw-r--r--   ggml/src/ggml-quants.c                 | 139
-rw-r--r--   ggml/src/ggml-quants.h                 |   5
-rw-r--r--   ggml/src/ggml.c                        |  26
-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp           | 166
-rw-r--r--   include/llama.h                        |   1
-rw-r--r--   src/llama.cpp                          |   3
19 files changed, 678 insertions, 13 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 6c298d2d..75dd78e6 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2242,6 +2242,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "q5_1") {
return GGML_TYPE_Q5_1;
}
+ if (s == "q6_0") {
+ return GGML_TYPE_Q6_0;
+ }
throw std::runtime_error("Invalid cache type: " + s);
}
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index fc77be50..9e4fd266 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -327,6 +327,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "iq4_nl") {
return GGML_TYPE_IQ4_NL;
}
+ if (s == "q6_0") {
+ return GGML_TYPE_Q6_0;
+ }
return GGML_TYPE_COUNT;
}
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index c11b8631..2b240299 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -20,6 +20,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 3.90G, +0.1585 ppl @ LLaMA-v1-7B", },
{ "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", },
{ "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
+ { "Q6_0", LLAMA_FTYPE_MOSTLY_Q6_0, " 6.5 bpw quantization", },
{ "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS, " 2.06 bpw quantization", },
{ "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", },
{ "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S, " 2.5 bpw quantization", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 36cc531f..08fe6a3e 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -392,6 +392,7 @@ extern "C" {
GGML_TYPE_Q4_0_4_8 = 32,
GGML_TYPE_Q4_0_8_8 = 33,
//
+ GGML_TYPE_Q6_0 = 133,
GGML_TYPE_IQ1_BN = 134,
GGML_TYPE_IQ2_BN = 135,
GGML_TYPE_Q8_K64 = 136,
@@ -447,6 +448,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
//
+ GGML_FTYPE_MOSTLY_Q6_0 = 127, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_BN = 128, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_BN = 129, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_K = 130, // except 1d tensors
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index bb0c4864..02ecf071 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -95,6 +95,9 @@ typedef sycl::half2 ggml_half2;
#define QI5_1 (QK5_1 / (4 * QR5_1))
#define QR5_1 2
+#define QI6_0 (QK6_0 / (4 * QR6_0))
+#define QR6_0 2
+
#define QI8_0 (QK8_0 / (4 * QR8_0))
#define QR8_0 1
@@ -187,6 +190,14 @@ typedef struct {
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+#define QK6_0 32
+typedef struct {
+ ggml_half d; // delta
+ uint8_t qh[QK6_0/4]; // 5th and 6th bits of the quants
+ uint8_t qs[QK6_0/2]; // nibbles / quants
+} block_q6_0;
+static_assert(sizeof(block_q6_0) == sizeof(ggml_half) + QK6_0/2 + QK6_0/4, "wrong q6_0 block size/padding");
+
#define QK8_0 32
typedef struct {
ggml_half d; // delta
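
The new block_q6_0 above stores, per 32 weights, one fp16 scale, 8 bytes of high bits (2 bits per quant, 4 quants per byte) and 16 bytes of low nibbles: 2 + 8 + 16 = 26 bytes, i.e. 26*8/32 = 6.5 bits per weight, which is where the "6.5 bpw" string in quantize.cpp comes from. A minimal standalone C sketch of the same layout, only to check the size arithmetic (hypothetical names, not part of the patch):

    #include <assert.h>
    #include <stdint.h>

    #define QK6_0 32

    typedef uint16_t ggml_half_t;      /* assumption: fp16 stored as 16 bits, like ggml_half */

    typedef struct {
        ggml_half_t d;                 /* per-block scale */
        uint8_t qh[QK6_0/4];           /* upper 2 bits of each quant, 4 quants per byte */
        uint8_t qs[QK6_0/2];           /* lower 4 bits of each quant, 2 quants per byte */
    } block_q6_0_sketch;

    int main(void) {
        /* 2 + 8 + 16 = 26 bytes per 32 weights -> 6.5 bits per weight */
        assert(sizeof(block_q6_0_sketch) == 26);
        return 0;
    }
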
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 617dd58f..64cc7592 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2807,6 +2807,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
@@ -2880,6 +2881,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
return true;
}
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q6_0) {
+ return true;
+ }
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index d75b219b..d7e9c529 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -376,6 +376,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q6_0> {
+ static constexpr int qk = QK6_0;
+ static constexpr int qr = QR6_0;
+ static constexpr int qi = QI6_0;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
static constexpr int qk = QK8_0;
static constexpr int qr = QR8_0;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index c74b030b..7089a6df 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -129,6 +129,36 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
}
}
+template<typename dst_t>
+static __global__ void dequantize_block_q6_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+ const int64_t i = blockIdx.x;
+
+ // assume 32 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8;
+ const int64_t ir = tid%8;
+ const int64_t ib = 8*i + ir;
+ if (ib >= nb32) {
+ return;
+ }
+
+ dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+ const block_q6_0 * x = (const block_q6_0 *)vx + ib;
+ const float d = __half2float(x->d);
+ const float dm = -32*d;
+
+ const uint8_t * qs = x->qs + 4*il;
+ const uint8_t * qh = x->qh + 4*(il%2);
+
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t h = qh[l] >> 4*(il/2);
+ y[l+ 0] = d * ((qs[l] & 0xF) | ((h << 4) & 0x30)) + dm;
+ y[l+16] = d * ((qs[l] >> 4) | ((h << 2) & 0x30)) + dm;
+ }
+}
+
//================================== k-quants
template<typename dst_t>
@@ -768,6 +798,14 @@ static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t n
}
template<typename dst_t>
+static void dequantize_row_q6_0_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int nb32 = k / 32;
+ const int nb = (k + 255) / 256;
+ dequantize_block_q6_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb32 = k / 32;
@@ -1004,6 +1042,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q6_0:
+ return dequantize_row_q6_0_cuda;
case GGML_TYPE_Q8_0:
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
return dequantize_block_q8_0_f16_cuda;
@@ -1074,6 +1114,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q6_0:
+ return dequantize_row_q6_0_cuda;
case GGML_TYPE_Q8_0:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
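
For reference, the dequantization done by dequantize_block_q6_0 above (and by dequantize_row_q6_0 added to ggml-quants.c further down) reduces to the following scalar per-block routine: each 6-bit code is rebuilt from a low nibble in qs plus two high bits in qh, then offset by -32 and scaled. A standalone C sketch of one block (hypothetical helper, not part of the patch):

    #include <stdint.h>

    static void dequant_q6_0_block(const uint8_t *qs, const uint8_t *qh, float d, float *y) {
        for (int j = 0; j < 16; ++j) {
            /* codes j and j+16 share byte qs[j]; their high bits sit in qh[j%8],
               in the low or high nibble of that byte depending on j/8 */
            const uint8_t h  = qh[j % 8] >> 4*(j/8);
            const int     q0 = ((qs[j] & 0x0F) | ((h << 4) & 0x30)) - 32;  /* low nibble + bits 4..5 */
            const int     q1 = ((qs[j] >>   4) | ((h << 2) & 0x30)) - 32;  /* high nibble + bits 4..5 */
            y[j]      = d * q0;
            y[j + 16] = d * q1;
        }
    }
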
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 1a84a4cb..0b269a86 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -221,6 +221,41 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
memcpy(dsti->qh, &qh, sizeof(qh));
}
+static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q6_0 * dsti = (block_q6_0 *) cdsti;
+
+ float amax = 0.0f;
+ float vmax = 0.0f;
+
+ for (int j = 0; j < QK6_0; ++j) {
+ const float v = xi[j];
+ const float av = fabsf(xi[j]);
+ if (amax < av) {
+ amax = av;
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -32;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->d = d;
+ memset(dsti->qh, 0, QK6_0/4);
+
+ for (int j = 0; j < QK6_0/2; ++j) {
+ const float x0 = xi[0 + j]*id;
+ const float x1 = xi[QK6_0/2 + j]*id;
+
+ const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
+ const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));
+
+ dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
+ dsti->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
+ }
+}
+
static __device__ const int8_t iq4nl_index[241] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
@@ -397,6 +432,17 @@ static void ggml_cpy_f32_q5_1_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
+static void ggml_cpy_f32_q6_0_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK6_0 == 0);
+ const int num_blocks = ne / QK6_0;
+ cpy_f32_q<cpy_blck_f32_q6_0, QK6_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -466,6 +512,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
+ ggml_cpy_f32_q6_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
@@ -505,6 +553,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q6_0, QK6_0>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
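
The quantization side used by cpy_blck_f32_q6_0 above works the same way in scalar form: pick d so that the value with the largest magnitude maps to code 0 (i.e. -32), round every scaled value to a code in [0, 63], and split each code into a low nibble in qs plus two high bits in qh. A plain C sketch under those assumptions (hypothetical function name, not part of the patch):

    #include <math.h>
    #include <stdint.h>
    #include <string.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    /* 32 floats in; fills 16 bytes of low nibbles (qs) and 8 bytes of high bits (qh); returns the scale */
    static float quantize_q6_0_block(const float *x, uint8_t *qs, uint8_t *qh) {
        float amax = 0.0f, vmax = 0.0f;
        for (int j = 0; j < 32; ++j) {
            if (fabsf(x[j]) > amax) { amax = fabsf(x[j]); vmax = x[j]; }
        }
        const float d  = vmax / -32;            /* signed max maps to code 0, i.e. value -32 */
        const float id = d ? 1.0f/d : 0.0f;

        memset(qh, 0, 8);
        for (int j = 0; j < 16; ++j) {
            /* +32.5f shifts the signed range [-32, 31] to [0, 63] and rounds to nearest */
            const uint8_t xi0 = MIN(63, (int8_t)(x[j     ]*id + 32.5f));
            const uint8_t xi1 = MIN(63, (int8_t)(x[j + 16]*id + 32.5f));
            qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);         /* low 4 bits of both codes */
            const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);   /* top 2 bits of both codes */
            qh[j % 8] |= (uint8_t)(h << 4*(j/8));               /* 4 codes per qh byte */
        }
        return d;   /* stored as fp16 in block_q6_0::d by the caller */
    }
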
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 5f932fef..15e8fb5a 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -9,6 +9,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
+ type == GGML_TYPE_Q6_0 ? vec_dot_q6_0_q8_1 :
type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
@@ -34,6 +35,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
+ type == GGML_TYPE_Q6_0 ? VDR_Q6_0_Q8_1_MMVQ :
type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
@@ -232,6 +234,13 @@ static void mul_mat_vec_q5_1_q8_1_cuda(
mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
+static void mul_mat_vec_q6_0_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q6_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
static void mul_mat_vec_q8_0_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -384,6 +393,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_Q5_1:
mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
+ case GGML_TYPE_Q6_0:
+ mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index b1b465a3..7baabb7a 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -48,6 +48,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
}
+#define VDR_Q6_0_Q8_1_MMVQ 2
+#define VDR_Q6_0_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q6_0_q8_1_impl(
+ const int * vl, const int * vh, const int * u, const float & d6, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ const int vi0 = ((vl[i] >> 0) & 0x0F0F0F0F) | ((vh[i] << 4) & 0x30303030);
+ const int vi1 = ((vl[i] >> 4) & 0x0F0F0F0F) | ((vh[i] << 2) & 0x30303030);
+
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
+ }
+
+ const float2 ds8f = __half22float2(ds8);
+
+ // second part effectively subtracts 32 from each quant value
+ return d6 * (sumi * ds8f.x - (32.f*vdr/QI6_0) * ds8f.y);
+}
+
#define VDR_Q4_1_Q8_1_MMVQ 2
#define VDR_Q4_1_Q8_1_MMQ 4
@@ -549,6 +573,26 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
}
+static __device__ __forceinline__ float vec_dot_q6_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q6_0 * bq6_0 = (const block_q6_0 *) vbq + kbx;
+
+ int vl[VDR_Q6_0_Q8_1_MMVQ];
+ int vh[VDR_Q6_0_Q8_1_MMVQ];
+ int u[2*VDR_Q6_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q6_0_Q8_1_MMVQ; ++i) {
+ vl[i] = get_int_b2(bq6_0->qs, iqs + i);
+ vh[i] = get_int_b2(bq6_0->qh, i) >> 4*(iqs/2);
+ u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI6_0);
+ }
+
+ return vec_dot_q6_0_q8_1_impl<VDR_Q6_0_Q8_1_MMVQ>(vl, vh, u, bq6_0->d, bq8_1->ds);
+}
+
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
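
In vec_dot_q6_0_q8_1_impl above, the dp4a products are accumulated on the raw unsigned codes (0..63); the -32 offset is folded in afterwards using the per-block sum that block_q8_1 carries in ds.y. QI6_0 = QK6_0/(4*QR6_0) = 4 ints of low nibbles per block, so with VDR_Q6_0_Q8_1_MMVQ = 2 each MMVQ thread covers half a block and subtracts half of the full-block correction 32*ds.y, which is the (32.f*vdr/QI6_0) = 16 factor. The underlying identity, as a scalar C reference (hypothetical helper; assumes ds.y == d8 * sum of the q8 codes, not part of the patch):

    #include <stdint.h>

    /* sum_i d6*(q6_i - 32) * d8*q8_i
       == d6*d8*sum_i q6_i*q8_i  -  32*d6*(d8*sum_i q8_i)  */
    static float dot_q6_q8_reference(const uint8_t *q6, const int8_t *q8,
                                     float d6, float d8, int n) {
        int sumi = 0, sum8 = 0;
        for (int i = 0; i < n; ++i) {
            sumi += q6[i] * q8[i];   /* what the dp4a calls accumulate */
            sum8 += q8[i];
        }
        return d6 * (sumi * d8 - 32.0f * (d8 * sum8));
    }
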
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 774314df..dcdd0efe 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -81,6 +81,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
@@ -121,6 +122,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
@@ -155,6 +157,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
@@ -186,6 +189,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
@@ -217,6 +221,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
@@ -271,6 +276,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CONCAT,
GGML_METAL_KERNEL_TYPE_SQR,
@@ -603,6 +609,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0, get_rows_q6_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
@@ -643,6 +650,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_0_F32, mul_mv_q6_0_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
@@ -677,6 +685,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32, mul_mv_id_q6_0_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
@@ -708,6 +717,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32, mul_mm_q6_0_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
@@ -739,6 +749,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32, mul_mm_id_q6_0_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
@@ -793,6 +804,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0, cpy_f32_q6_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
@@ -960,6 +972,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
case GGML_TYPE_IQ4_NL:
return true;
default:
@@ -1910,6 +1923,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
@@ -2028,6 +2042,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
} break;
+ case GGML_TYPE_Q6_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_0_F32].pipeline;
+ } break;
case GGML_TYPE_Q8_0:
{
nth0 = 8;
@@ -2200,7 +2220,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
@@ -2293,6 +2313,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
@@ -2398,6 +2419,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
} break;
+ case GGML_TYPE_Q6_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline;
+ } break;
case GGML_TYPE_Q8_0:
{
nth0 = 8;
@@ -2581,7 +2608,7 @@ static enum ggml_status ggml_metal_graph_compute(
const int64_t _ne1 = 1;
const int tgz = dst_rows;
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
@@ -2632,6 +2659,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
+ case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
@@ -3293,6 +3321,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+ case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
default: GGML_ABORT("not implemented");
};
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index c1e11047..225fa5f1 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -1281,8 +1281,30 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+// function to calculate the inner product between half a q6_0 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q6 quants begin (0 or QK6_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q6_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qh = (device const uint16_t *)qb_curr->qh;
+ device const uint16_t * qs = (device const uint16_t *)qb_curr->qs + il/2;
+
+ const int shift = 4*(il/8);
+ for (int i = 0; i < 8; i += 2) {
+ acc[0] += yl[i + 0] * ((qs[i/2] & 0x000F) | ((qh[i/2] << (4-shift)) & 0x0030))
+ + yl[i + 1] * ((qs[i/2] & 0x0F00) | ((qh[i/2] << (4-shift)) & 0x3000));
+ acc[1] += yl[i + 8] * ((qs[i/2] & 0x00F0) | ((qh[i/2] << (6-shift)) & 0x0300))
+ + yl[i + 9] * ((qs[i/2] & 0xF000) | (((uint32_t)qh[i/2] << (6-shift)) & 0x30000));
}
- return d * (acc[0] + acc[1]) + sumy * m;
+ return d * (sumy * -32.f + acc[0] + acc[1]);
}
// putting them in the kernel cause a significant performance penalty
@@ -1464,6 +1486,31 @@ kernel void kernel_mul_mv_q5_1_f32(
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
+kernel void kernel_mul_mv_q6_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q6_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
#define NB_Q8_0 8
@@ -3480,6 +3527,77 @@ kernel void kernel_cpy_f32_q5_1(
}
}
+kernel void kernel_cpy_f32_q6_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK6_0;
+
+ device block_q6_0 * dst_data = (device block_q6_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK6_0; i00 < ne00; i00 += ntg.x*QK6_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK6_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -32;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ device block_q6_0 & b6 = dst_data[i00/QK6_0];
+ b6.d = d;
+ device uint16_t * aux16 = (device uint16_t *)b6.qh;
+ aux16[0] = aux16[1] = aux16[2] = aux16[3] = 0;
+
+ for (int j = 0; j < QK6_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK6_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(63, (int8_t)(x0 + 32.5f));
+ const uint8_t xi1 = MIN(63, (int8_t)(x1 + 32.5f));
+
+ b6.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
+ b6.qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
+ }
+ }
+}
+
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
@@ -6844,6 +6962,21 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
}
template <typename type4x4>
+void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg) {
+ const float d = xb->d;
+ const float m = -32.h * xb->d;
+ device const uint8_t * qh = xb->qh;
+ device const uint8_t * qs = qh + 8;
+
+ for (int i = 0; i < 8; i++) {
+ reg[i/4][i%4] = d * (((qs[i] >> 4*il) & 0xf) | (((qh[i] >> 2*il) << 4) & 0x30)) + m;
+ }
+ for (int i = 0; i < 8; i++) {
+ reg[2+i/4][i%4] = d * (((qs[i+8] >> 4*il) & 0xf) | ((qh[i] >> 2*il) & 0x30)) + m;
+ }
+}
+
+template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;
@@ -7839,6 +7972,7 @@ template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_get_rows_q6_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_0, 2, dequantize_q6_0>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -7880,6 +8014,7 @@ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_1, 2, dequantize_q4_1>>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_0, 2, dequantize_q5_0>>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_1, 2, dequantize_q5_1>>;
+template [[host_name("kernel_mul_mm_q6_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q6_0, 2, dequantize_q6_0>>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q8_0, 2, dequantize_q8_0>>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q2_K, QK_NL, dequantize_q2_K>>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q3_K, QK_NL, dequantize_q3_K>>;
@@ -7918,6 +8053,7 @@ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_1, 2, dequantize_q4_1>>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_0, 2, dequantize_q5_0>>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_1, 2, dequantize_q5_1>>;
+template [[host_name("kernel_mul_mm_id_q6_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q6_0, 2, dequantize_q6_0>>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q8_0, 2, dequantize_q8_0>>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q2_K, QK_NL, dequantize_q2_K>>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q3_K, QK_NL, dequantize_q3_K>>;
@@ -8138,6 +8274,7 @@ template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q6_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q6_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_tn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_tn_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index bef2f73e..f5fff22e 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -848,6 +848,59 @@ void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q5_1_ref(x, y, k);
}
+void quantize_row_q6_0_ref(const float * restrict x, block_q6_0 * restrict y, int64_t k) {
+ static const int qk = QK6_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -32;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ //y[i].d = GGML_FP32_TO_FP16(d);
+ memset(y[i].qh, 0, qk/4);
+
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = x[i*qk + 0 + j]*id;
+ const float x1 = x[i*qk + qk/2 + j]*id;
+ const float w0 = x0*x0;
+ const float w1 = x1*x1;
+
+ const uint8_t xi0 = MIN(63, (int8_t)(x0 + 32.5f));
+ const uint8_t xi1 = MIN(63, (int8_t)(x1 + 32.5f));
+
+ y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+ const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
+ y[i].qh[j%(qk/4)] |= (h << 4*(j/(qk/4)));
+
+ const float q0 = (float)xi0 - 32.f;
+ const float q1 = (float)xi1 - 32.f;
+ sumqx += w0*x[i*qk + j]*q0 + w1*x[i*qk + qk/2 + j]*q1;
+ sumq2 += w0*q0*q0 + w1*q1*q1;
+ }
+ y[i].d = sumq2 > 0 ? GGML_FP32_TO_FP16(sumqx/sumq2) : GGML_FP32_TO_FP16(d);
+ }
+}
+
+void quantize_row_q6_0(const float * restrict x, void * restrict y, int64_t k) {
+ quantize_row_q6_0_ref(x, y, k);
+}
+
// reference implementation for deterministic creation of model files
void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
assert(k % QK8_0 == 0);
@@ -1691,6 +1744,28 @@ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int6
}
}
+void dequantize_row_q6_0(const block_q6_0 * restrict x, float * restrict y, int64_t k) {
+ static const int qk = QK6_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t h = x[i].qh[j%(qk/4)] >> 4*(j/(qk/4));
+
+ const int32_t x0 = ((x[i].qs[j] & 0x0F) | ((h << 4) & 0x30)) - 32;
+ const int32_t x1 = ((x[i].qs[j] >> 4) | ((h << 2) & 0x30)) - 32;
+
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
+ }
+ }
+}
+
void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) {
static const int qk = QK8_0;
@@ -3429,6 +3504,54 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr
return nrow * row_size;
}
+static void quantize_row_q6_0_impl(const float * restrict x, block_q6_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
+ static_assert(QK6_0 == 32, "QK6_0 must be 32");
+
+ float weight[QK6_0];
+ int8_t L[QK6_0];
+
+ float sigma2 = 0;
+ if (quant_weights) {
+ float sum_x2 = 0;
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+ sigma2 = sum_x2/n_per_row;
+ }
+
+ const int64_t nb = n_per_row/QK6_0;
+ for (int ib = 0; ib < nb; ++ib) {
+ const float * xb = x + QK6_0 * ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + QK6_0 * ib;
+ for (int j = 0; j < QK6_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < QK6_0; ++j) weight[j] = xb[j]*xb[j];
+ }
+ float d = make_qx_quants(QK6_0, 32, xb, L, 1, weight);
+ y[ib].d = GGML_FP32_TO_FP16(d);
+
+ memset(y[ib].qh, 0, QK6_0/4);
+
+ for (int j = 0; j < 16; ++j) {
+ const uint8_t xi0 = L[j];
+ const uint8_t xi1 = L[j+16];
+ y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+ const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
+ y[ib].qh[j%8] |= (h << 4*(j/8));
+ }
+ }
+}
+
+size_t quantize_q6_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ size_t row_size = ggml_row_size(GGML_TYPE_Q6_0, n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q6_0_impl(src, (block_q6_0*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrow * row_size;
+}
+
size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights; // not used
const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
@@ -5383,6 +5506,21 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
*s = sumf;
}
+void ggml_vec_dot_q6_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+#ifdef __AVX2__
+ const enum ggml_type vec_dot_type = GGML_TYPE_Q8_1;
+#else
+ const enum ggml_type vec_dot_type = GGML_TYPE_Q8_0;
+#endif
+ if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q6_0, vx, bx, vec_dot_type, vy, by, s, bs, 0, 1)) {
+ return;
+ }
+#endif
+ // TODO
+ *s = 0;
+}
+
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
@@ -15020,6 +15158,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break;
+ case GGML_TYPE_Q6_0: break;
case GGML_TYPE_IQ2_K: break;
case GGML_TYPE_IQ3_K: break;
case GGML_TYPE_IQ4_K: break;
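
Note the scale handling in quantize_row_q6_0_ref above: the initial d = max/-32 is only used to pick the integer codes; the stored scale is then refit as the least-squares optimum for those fixed codes, with each term weighted by x^2 (w0 = x0*x0). Setting the derivative of sum_j w_j*(x_j - d*q_j)^2 with respect to d to zero gives d = sum_j w_j*x_j*q_j / sum_j w_j*q_j^2, which is exactly the sumqx/sumq2 expression. A small C sketch of that refit step (hypothetical helper, not part of the patch):

    #include <stdint.h>

    /* least-squares refit of the block scale for fixed integer codes q_j in -32..31 */
    static float refit_scale(const float *x, const float *w, const int8_t *q, int n, float d_fallback) {
        float sumqx = 0.0f, sumq2 = 0.0f;
        for (int j = 0; j < n; ++j) {
            sumqx += w[j] * x[j] * q[j];
            sumq2 += w[j] * q[j] * q[j];
        }
        return sumq2 > 0.0f ? sumqx/sumq2 : d_fallback;  /* fall back to the rounding scale */
    }
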
diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h
index 775aa875..bad7e9d9 100644
--- a/ggml/src/ggml-quants.h
+++ b/ggml/src/ggml-quants.h
@@ -25,6 +25,7 @@ void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_REST
void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_0_ref(const float * GGML_RESTRICT x, block_q6_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
@@ -48,6 +49,7 @@ void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -72,6 +74,7 @@ void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRI
void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q6_0(const block_q6_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -98,6 +101,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q6_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -140,6 +144,7 @@ size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q6_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
void iq2xs_init_impl(enum ggml_type type);
void iq2xs_free_impl(enum ggml_type type);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index ee83fc43..d31713df 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -799,6 +799,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_Q6_0] = {
+ .type_name = "q6_0",
+ .blck_size = QK6_0,
+ .type_size = sizeof(block_q6_0),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q6_0,
+ .from_float = quantize_row_q6_0,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q6_0_ref,
+ .vec_dot = ggml_vec_dot_q6_0_q8_0,
+#if GGML_USE_IQK_MULMAT && defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0,
+#endif
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_Q8_0] = {
.type_name = "q8_0",
.blck_size = QK8_0,
@@ -3788,6 +3805,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
+ case GGML_FTYPE_MOSTLY_Q6_0: wtype = GGML_TYPE_Q6_0; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
@@ -10237,6 +10255,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
@@ -10623,6 +10642,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
@@ -10760,6 +10780,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
@@ -13858,6 +13879,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
@@ -14234,6 +14256,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
@@ -14505,6 +14528,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
@@ -15103,6 +15127,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
@@ -21899,6 +21924,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q6_0: result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index d16f01d9..0c1c1625 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -3228,6 +3228,16 @@ struct Q5_1_Dequantizer {
return _mm256_or_si256(b4.dequant(x->qs), vqh);
}
};
+struct Q6_1_Dequantizer {
+ Dequantizer4bit b4;
+ const __m256i mh = _mm256_set1_epi8(0x30);
+ inline __m256i dequant(const block_q6_0 * x) const {
+ uint64_t aux64; std::memcpy(&aux64, x->qh, 8);
+ auto h128 = _mm_set_epi64x(aux64, aux64 << 4);
+ auto h256 = MM256_SET_M128I(_mm_srli_epi16(h128, 2), h128);
+ return _mm256_or_si256(b4.dequant(x->qs), _mm256_and_si256(h256, mh));
+ }
+};
template <typename Q, typename Scales, typename Dequantizer>
struct Q_Unpacker {
@@ -3332,6 +3342,11 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_
using Sum4T = Sum4Type1;
inline static int block_size() { return QK4_1; }
};
+struct Q6_0_1_Unpacker final : public Q_Unpacker<block_q6_0, ScaleHelperQ_0_1<32>, Q6_1_Dequantizer> {
+ Q6_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ81;
+ inline static int block_size() { return QK6_0; }
+};
// float matrices - we handle f16, bf16 (if native bf16 support is available) and f32, but only to f32 result
@@ -3628,7 +3643,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
}
else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker> ||
std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || std::is_same_v<Dequantizer, Q4_0_1_Unpacker> ||
- std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) {
+ std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker> ||
+ std::is_same_v<Dequantizer, Q6_0_1_Unpacker>) {
m.funcs[0] = mul_mat_qX_1_q8_1_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_1_q8_1_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_1_q8_1_T<Dequantizer, 3>;
@@ -3893,8 +3909,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
- //MulMat::set_functions<Q4_0_Unpacker>(mm);
- //expected_typeB = GGML_TYPE_Q8_0;
MulMat::set_functions<Q4_0_1_Unpacker>(mm);
expected_typeB = GGML_TYPE_Q8_1;
break;
@@ -3905,8 +3919,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_Q5_0:
assert (ne00 % QK5_0 == 0);
- //MulMat::set_functions<Q5_0_Unpacker>(mm);
- //expected_typeB = GGML_TYPE_Q8_0;
MulMat::set_functions<Q5_0_1_Unpacker>(mm);
expected_typeB = GGML_TYPE_Q8_1;
break;
@@ -3915,10 +3927,13 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
MulMat::set_functions<Q5_1_Unpacker>(mm);
expected_typeB = GGML_TYPE_Q8_1;
break;
+ case GGML_TYPE_Q6_0:
+ assert (ne00 % QK6_0 == 0);
+ MulMat::set_functions<Q6_0_1_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_1;
+ break;
case GGML_TYPE_Q8_0:
assert (ne00 % QK8_0 == 0);
- //MulMat::set_functions<Q8_0_Unpacker>(mm);
- //expected_typeB = GGML_TYPE_Q8_0;
MulMat::set_functions<Q8_0_1_Unpacker>(mm);
expected_typeB = GGML_TYPE_Q8_1;
break;
@@ -5417,6 +5432,34 @@ struct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> {
//ggml_half aux[4];
};
+struct DequantizerQ60 final : public BaseLegacyDequantizer<block_q6_0> {
+
+ DequantizerQ60(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ auto qh8 = vld1_u8(x[i].qh);
+ auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+ q[0] = vaddq_s8(vorrq_u8(q[0], vandq_u8(qh, hmask)), m32);
+ q[1] = vaddq_s8(vorrq_u8(q[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32);
+ }
+ inline void prepare1(int i) {
+ prepare1(i, bits.b);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+
+ const int8x16_t m32 = vdupq_n_s8(-32);
+ const uint8x16_t hmask = vdupq_n_u8(0x30);
+};
+
struct DequantizerIQ4NL final : public BaseLegacyDequantizer<block_iq4_nl> {
DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
@@ -6325,7 +6368,8 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||
- std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL>) {
+ std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL> ||
+ std::is_same_v<Dequantizer, DequantizerQ60>) {
m.funcs[0] = mul_mat_qX_0_q8_0<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_0_q8_0<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_0_q8_0<Dequantizer, 3>;
@@ -6492,6 +6536,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
MulMat::set_functions<DequantizerQ51>(m);
expected_Btype = GGML_TYPE_Q8_1;
break;
+ case GGML_TYPE_Q6_0:
+ MulMat::set_functions<DequantizerQ60>(m);
+ expected_Btype = GGML_TYPE_Q8_0;
+ break;
case GGML_TYPE_Q8_0:
MulMat::set_functions<DequantizerQ80>(m);
expected_Btype = GGML_TYPE_Q8_0;
@@ -7227,6 +7275,64 @@ struct HelperIQ4nl final : public BaseHelper<step> {
#endif
};
+template <int D, int step>
+struct HelperQ60 final : public BaseHelper<step> {
+#ifdef __aarch64__
+ using block_q8 = block_q8_0;
+#else
+ using block_q8 = block_q8_1;
+#endif
+ using Base = BaseHelper<step>;
+ HelperQ60(const char * data, int stride) : Base(data, stride) {}
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ int j = F16::block_size*i;
+ auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0;
+#ifdef __aarch64__
+ // TODO
+ auto vd = F16::set1(*(const float16_t *)&dl->d);
+ auto qh8 = vld1_u8(dl->qh);
+ auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+ auto qs = vld1q_u8(dl->qs);
+ qs = j%QK4_0 ? vshrq_n_u8(qs, 4) : vandq_u8(qs, mask_l);
+ qs = vorrq_u8(qs, vandq_u8(mask_h, j%QK4_0 ? vshrq_n_u8(qh, 2) : qh));
+ qs = vaddq_s8(qs, m32);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(qs))));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(qs))));
+#else
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+ auto bl = _mm_loadu_si128((const __m128i *)dl->qs);
+ uint64_t aux64; std::memcpy(&aux64, dl->qh, 8);
+ auto bh = _mm_set_epi64x(aux64, aux64 << 4);
+#ifdef HAVE_FANCY_SIMD
+ auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32);
+ auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32);
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+ if (j%QK4_0) {
+ bl = _mm_srli_epi16(bl, 4);
+ bh = _mm_srli_epi16(bh, 2);
+ }
+ auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32));
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
+#endif
+#endif
+ }
+
+#ifdef __AVX2__
+ const __m128i mask_l = _mm_set1_epi8(0x0f);
+ const __m128i mask_h = _mm_set1_epi8(0x30);
+ const __m128i m32 = _mm_set1_epi8(-32);
+#else
+ const uint8x16_t mask_l = vdupq_n_u8(0x0f);
+ const uint8x16_t mask_h = vdupq_n_u8(0x30);
+ const int8x16_t m32 = vdupq_n_s8(-32);
+#endif
+};
+
template <int q_step, int k_step>
struct FlashMS {
// Something goes wrong when storing and manipulating K*Q as fp16.
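HelperQ60::load above reuses the same layout; the only non-obvious step is how the 8 qh bytes are widened to 16 lanes ({qh<<4, qh} on NEON, _mm_set_epi64x(aux64, aux64 << 4) on AVX2, where the full 64-bit shift leaks bits across byte boundaries but the per-byte AND with 0x30 discards them). Below is a hedged scalar model of that expansion for the NEON/AVX2 paths, which unpack 16 values per call (the AVX512 path handles the whole block at once); the function name is illustrative only:

// Hedged scalar model of the qh expansion in HelperQ60::load, under the same layout
// assumptions as the sketch after DequantizerQ60; not the shipped code.
// 'second_half' mirrors the j%QK4_0 test above (QK4_0 == QK6_0 == 32 here), i.e.
// whether the call covers values 0..15 or 16..31 of the block.
static void expand_q6_0_high_bits_sketch(const uint8_t * qh, int second_half, uint8_t * hbits) {
    for (int k = 0; k < 16; ++k) {
        uint8_t b = k < 8 ? (uint8_t)(qh[k] << 4) : qh[k - 8];  // the {qh<<4, qh} combine
        if (second_half) b >>= 2;                               // select bit pairs 2-3 / 6-7
        hbits[k] = b & 0x30;                                    // two high bits, already at 4-5
    }
}

ORing hbits[k] into the corresponding low nibble and adding -32 reproduces the 6-bit quant, exactly as in prepare1() above.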
@@ -7759,6 +7865,14 @@ struct FlashQKfp32 {
mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
#endif
}
+ else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
+#ifdef __aarch64__
+ mul_mat_qX_0_q8_0<DequantizerQ60, q_step>(D, kh.block, kh.stride, info, k_step);
+#else
+ mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+#endif
+ }
else {
GGML_ASSERT(false);
}
@@ -7880,6 +7994,28 @@ struct FlashQKfp32 {
#endif
}
}
+ else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
+ switch (nq) {
+#ifdef __aarch64__
+ case 1: mul_mat_qX_0_q8_0<DequantizerQ60, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_0_q8_0<DequantizerQ60, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_0_q8_0<DequantizerQ60, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_0_q8_0<DequantizerQ60, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_0_q8_0<DequantizerQ60, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_0_q8_0<DequantizerQ60, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_0_q8_0<DequantizerQ60, 7>(D, kh.block, kh.stride, info, k_step); break;
+#else
+ case 1: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
+#endif
+ }
+ }
else {
GGML_ASSERT(false);
}
@@ -8019,7 +8155,8 @@ struct FlashAttn {
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
- std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
+ std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
+ std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else {
@@ -8364,6 +8501,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperIQ4nl<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_Q6_0: {
+ HelperQ60<D, k_step> vh(v, stride_v);
+ iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
default: break;
}
}
@@ -8395,6 +8536,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperIQ4nl<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_Q6_0: {
+ HelperQ60<D, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ } break;
default: break;
}
@@ -8404,7 +8549,8 @@ inline bool flash_attn_is_supported(ggml_type type) {
#ifdef __AVX512BF16__
if (type == GGML_TYPE_BF16) return true;
#endif
- if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL) return true;
+ if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
+ type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
return false;
}
}
diff --git a/include/llama.h b/include/llama.h
index 02d94b6c..43c0091e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -167,6 +167,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
//
+ LLAMA_FTYPE_MOSTLY_Q6_0 = 135, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_BN = 136, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ2_BN = 137, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ2_K = 138, // except 1d tensors
diff --git a/src/llama.cpp b/src/llama.cpp
index dca03ade..eb982125 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3774,6 +3774,7 @@ struct llama_model_loader {
case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break;
case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break;
case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break;
+ case GGML_TYPE_Q6_0: ftype = LLAMA_FTYPE_MOSTLY_Q6_0; break;
case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break;
case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break;
case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break;
@@ -4471,6 +4472,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
+ case LLAMA_FTYPE_MOSTLY_Q6_0: return "Q6_0";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
@@ -15967,6 +15969,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
+ case LLAMA_FTYPE_MOSTLY_Q6_0: default_type = GGML_TYPE_Q6_0; break;
case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
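With the ftype plumbing above in place, the new type can be requested through the public quantization API. A minimal sketch, assuming the standard llama_model_quantize entry point and default-params helper declared in include/llama.h; the file names are placeholders:

#include "llama.h"

int main() {
    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.ftype = LLAMA_FTYPE_MOSTLY_Q6_0;   // new ftype added by this patch
    // returns 0 on success; per the mapping above this requantizes tensors to GGML_TYPE_Q6_0
    return (int) llama_model_quantize("model-f16.gguf", "model-q6_0.gguf", &params);
}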