summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-10-16 15:18:26 +0300
committerGitHub <noreply@github.com>2024-10-16 15:18:26 +0300
commit76b97c80645362ac65a2e33043fd8d46bdaf8c56 (patch)
treeb2b8ab9efb91a6ce4dd9d0fccbc9e11141ca1d80
parent993ca95e9e3108f0352fa2a3384cab0775c7f7c1 (diff)
Adding IQ4_KSS: 4.0 bpw quants (#89)
* iq4_kss: WIP * iq4_kss: CUDA dequantize works So we can run perplexity. Sadly, the result does not look good on the bpw vs quantization error plot. * iq4_kss: slightly better quantization * iq4_kss: another small quantization improvement * iq4_kss: CUDA works TG-128 performance is very decent with 131 t/s for LLaMA-3.1-8B. In comparison, we have 123 t/s for q4_0 and 128 t/s for iq4_ks. I.e., the reduced model size more than offsets the additional bit fiddling required for iq4_kss. * iq4_kss: new bit arrangement - CUDA and Zen4 work Did not lose performance on CUDA. Zen4 is decent, but not great: PP-512(LLaMA-3.1-8B) = 163 t/s. TG-128 is of course better than other 4-bit quants due to smaller model size. We get 14.5 t/s @ 8 threads. * iq4_kss: ARM_NEON. Predictably very slow * iq4_kss: Metal PP is not too bad - just 10% slower than q4_0. But TG is 30% slower, i.e., predictably bad. * iq4_kss: somewhat faster Metal dot product 45.75 t/s -> 48.75 t/s. Still 22% slower than q4_0 * iq4_kss: AVX2 Bad, but better than I expected. PP-512(LLaMA-3.1-8B) = 167 t/s on the Ryzen-5950X. I.e., with 32 AVX2 threads we get the performance of 16 Zen4 threads. * iq4_kss: very slightly faster Metal dot product 48.7 t/s -> 49.3 t/s --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--examples/quantize-stats/quantize-stats.cpp50
-rw-r--r--examples/quantize/quantize.cpp1
-rw-r--r--ggml/include/ggml.h2
-rw-r--r--ggml/src/ggml-common.h5
-rw-r--r--ggml/src/ggml-cuda.cu1
-rw-r--r--ggml/src/ggml-cuda/common.cuh7
-rw-r--r--ggml/src/ggml-cuda/convert.cu43
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cu36
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cuh4
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu3
-rw-r--r--ggml/src/ggml-metal.m31
-rw-r--r--ggml/src/ggml-metal.metal168
-rw-r--r--ggml/src/ggml-quants.c1
-rw-r--r--ggml/src/ggml.c22
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp178
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp448
-rw-r--r--ggml/src/iqk/iqk_quantize.h6
-rw-r--r--include/llama.h1
-rw-r--r--src/llama.cpp15
19 files changed, 997 insertions, 25 deletions
diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp
index 34d05bf2..ff4e9bd4 100644
--- a/examples/quantize-stats/quantize-stats.cpp
+++ b/examples/quantize-stats/quantize-stats.cpp
@@ -256,6 +256,8 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo
float mse0 = 0, mse = 0;
auto compute = [&mutex, &counter, &mse0, &mse, values, row_size, nblock, nrows, n_per_row, chunk] () {
std::vector<char> Q(row_size);
+ float diff[4];
+ float xv[4];
float lmse0 = 0, lmse = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
@@ -282,25 +284,41 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo
for (int j = 0; j < 16; j += 2) {
uint16_t v0 = *(const uint16_t *)(qs + j);
int non = popcount(v0);
- float diff1 = xb[j+ 0] - dl*values[qs[j+0] & 0xf];
- float diff2 = xb[j+16] - dl*values[qs[j+0] >> 4];
- float diff3 = xb[j+ 1] - dl*values[qs[j+1] & 0xf];
- float diff4 = xb[j+17] - dl*values[qs[j+1] >> 4];
- lmse0 += diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
+ xv[0] = xb[j+ 0]; xv[1] = xb[j+16]; xv[2] = xb[j+ 1]; xv[3] = xb[j+17];
+ diff[0] = xv[0] - dl*values[qs[j+0] & 0xf];
+ diff[1] = xv[1] - dl*values[qs[j+0] >> 4];
+ diff[2] = xv[2] - dl*values[qs[j+1] & 0xf];
+ diff[3] = xv[3] - dl*values[qs[j+1] >> 4];
+ float diff4 = diff[0]*diff[0] + diff[1]*diff[1] + diff[2]*diff[2] + diff[3]*diff[3];
+ lmse0 += diff4;
if (non%2 == 0) {
- lmse += diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
+ lmse += diff4;
} else {
float best = std::numeric_limits<float>::max();
- for (int k = 0; k < 16; k += 4) {
- uint16_t v = v0 ^ (1 << k);
- uint8_t v1 = v;
- uint8_t v2 = v >> 8;
- diff1 = xb[j+ 0] - dl*values[v1 & 0xf];
- diff2 = xb[j+16] - dl*values[v1 >> 4];
- diff3 = xb[j+ 1] - dl*values[v2 & 0xf];
- diff4 = xb[j+17] - dl*values[v2 >> 4];
- float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
- if (score < best) best = score;
+ //for (int k = 0; k < 16; k += 4) {
+ // uint16_t v = v0 ^ (1 << k);
+ // uint8_t v1 = v;
+ // uint8_t v2 = v >> 8;
+ // diff1 = xb[j+ 0] - dl*values[v1 & 0xf];
+ // diff2 = xb[j+16] - dl*values[v1 >> 4];
+ // diff3 = xb[j+ 1] - dl*values[v2 & 0xf];
+ // diff4 = xb[j+17] - dl*values[v2 >> 4];
+ // float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
+ // if (score < best) best = score;
+ //}
+ for (int k = 0; k < 4; ++k) {
+ uint16_t v = (v0 >> 4*k) & 0xf;
+ auto pc = popcount(v);
+ if (v > 0 && popcount(v-1u) != pc) {
+ float this_diff = xv[k] - dl*values[v-1u];
+ float score = diff4 - diff[k]*diff[k] + this_diff*this_diff;
+ if (score < best) best = score;
+ }
+ if (v < 15 && popcount(v + 1u) != pc) {
+ float this_diff = xv[k] - dl*values[v+1u];
+ float score = diff4 - diff[k]*diff[k] + this_diff*this_diff;
+ if (score < best) best = score;
+ }
}
lmse += best;
}
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 1ace5720..8e0d0969 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -44,6 +44,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
{ "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", },
+ { "IQ4_KSS", LLAMA_FTYPE_MOSTLY_IQ4_KSS, " 4.0 bpw non-linear quantization", },
{ "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",},
{ "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",},
{ "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index fd7c23b9..a467c297 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -405,6 +405,7 @@ extern "C" {
GGML_TYPE_IQ1_TN = 143,
GGML_TYPE_IQ4_KS = 144,
GGML_TYPE_IQ2_KS = 145,
+ GGML_TYPE_IQ4_KSS = 146,
GGML_TYPE_COUNT,
};
@@ -462,6 +463,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ1_TN = 136, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors
};
// available tensor operations:
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 3a7b8989..f8824b0e 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -448,6 +448,11 @@ typedef struct {
static_assert(sizeof(block_iq4_ks) == QK_K/32 + QK_K/2, "wrong iq4_ks block size/padding");
typedef struct {
+ uint32_t qs[QK_K/8];
+} block_iq4_kss;
+static_assert(sizeof(block_iq4_kss) == QK_K/8*sizeof(uint32_t), "wrong iq4_kss block size/padding");
+
+typedef struct {
ggml_half d;
uint16_t extra;
uint8_t scales[QK_K/32];
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 6648b7f8..e26f36a0 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2829,6 +2829,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index a6a9c3d3..a5658a24 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -544,6 +544,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KS> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KSS> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR5_XS;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 1e4421b1..e9d15b5d 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -639,6 +639,37 @@ static __global__ void dequantize_block_iq4_ks(const void * __restrict__ vx, dst
}
template<typename dst_t>
+static __global__ void dequantize_block_iq4_kss(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const char * cx = (const char *)vx + row * row_size;
+ float scale = *(const float *)cx;
+ const block_iq4_kss * x = (const block_iq4_kss *)(cx + sizeof(float));
+ const int64_t i = ii - (row*n_per_row)/QK_K;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + ii*QK_K + 32*ib + 4*il;
+ const uint32_t * q4 = x[i].qs + 4*ib;
+ uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
+ uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
+ const float d = scale * ((ls & 254) - 127);
+ const int8_t * values = iq4k_values + ((ls & 1) << 4);
+ uint32_t aux32[2];
+ aux32[0] = q4[il] & 0xfffefffe;
+ aux32[0] ^= (aux32[0] >> 1);
+ aux32[1] = ((aux32[0] >> 4) & 0x0f0f0f0f);
+ aux32[0] &= 0x0f0f0f0f;
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * values[aux8[j+0]];
+ y[j+16] = d * values[aux8[j+4]];
+ }
+}
+
+template<typename dst_t>
static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq4_k * x = (const block_iq4_k *)vx;
@@ -981,6 +1012,14 @@ static void dequantize_row_iq4_ks_cuda(const void * vx, dst_t * y, const int64_t
}
template<typename dst_t>
+static void dequantize_row_iq4_kss_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KSS, n_per_row);
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_iq4_kss<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
+}
+
+template<typename dst_t>
static void dequantize_row_iq2_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_KS, n_per_row);
@@ -1152,6 +1191,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_KS:
return dequantize_row_iq4_ks_cuda;
+ case GGML_TYPE_IQ4_KSS:
+ return dequantize_row_iq4_kss_cuda;
case GGML_TYPE_IQ2_KS:
return dequantize_row_iq2_ks_cuda;
case GGML_TYPE_IQ2_K:
@@ -1225,6 +1266,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_KS:
return dequantize_row_iq4_ks_cuda;
+ case GGML_TYPE_IQ4_KSS:
+ return dequantize_row_iq4_kss_cuda;
case GGML_TYPE_IQ2_KS:
return dequantize_row_iq2_ks_cuda;
case GGML_TYPE_IQ2_K:
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index 9ca219e4..dec54b5e 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -239,6 +239,35 @@ __device__ __forceinline__ float vec_dot_iq4_ks_q8_1(
return dl * __low2float(bq8_1[ib32].ds) * sumi;
}
+#define VDR_IQ4_KSS_Q8_1_MMVQ 4
+#define VDR_IQ4_KSS_Q8_1_MMQ 4
+
+__device__ __forceinline__ float vec_dot_iq4_kss_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ float scale = *(const float *)vbq;
+ const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx;
+ const uint8_t * all_values = (const uint8_t *)iq4k_values;
+
+ // iqs is 0...28
+ const int ib32 = iqs/4; // Why iqs/4 ?
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+ const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
+ uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
+ uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
+ const float dl = scale * ((ls & 254) - 127);
+ int v1, v2;
+ int sumi = 0;
+ for (int j = 0; j < 4; ++j) {
+ uint32_t aux32 = q4[j] & 0xfffefffe;
+ aux32 ^= (aux32 >> 1);
+ get_int_from_table_16_shift(aux32, ls & 1, all_values, v1, v2);
+ sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
+ sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
+ }
+ return dl * __low2float(bq8_1[ib32].ds) * sumi;
+}
+
#define VDR_IQ5_K_Q8_1_MMVQ 4
#define VDR_IQ5_K_Q8_1_MMQ 4
@@ -703,6 +732,13 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
+void mul_mat_vec_iq4_kss_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
void mul_mat_vec_iq2_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 3a93a1b6..0678c026 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -32,6 +32,10 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+void mul_mat_vec_iq4_kss_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+
void mul_mat_vec_iq2_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index e312b266..107caf45 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -462,6 +462,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ4_KS:
mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
+ case GGML_TYPE_IQ4_KSS:
+ mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
case GGML_TYPE_IQ2_KS:
mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index ac183585..9f696383 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -107,6 +107,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K,
@@ -150,6 +151,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KSS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_K_F32,
@@ -187,6 +189,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32,
@@ -221,6 +224,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32,
@@ -255,6 +259,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32,
@@ -650,6 +655,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS, get_rows_iq4_ks, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS, get_rows_iq4_kss, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K, get_rows_iq2_k, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS, get_rows_iq2_ks, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K, get_rows_iq3_k, true);
@@ -693,6 +699,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32, mul_mv_iq4_ks_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KSS_F32, mul_mv_iq4_kss_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32, mul_mv_iq2_k_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KS_F32, mul_mv_iq2_ks_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_K_F32, mul_mv_iq3_k_f32, ctx->support_simdgroup_reduction);
@@ -730,6 +737,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32, mul_mv_id_iq4_ks_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32, mul_mv_id_iq4_kss_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32, mul_mv_id_iq2_k_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32, mul_mv_id_iq2_ks_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32, mul_mv_id_iq3_k_f32, ctx->support_simdgroup_reduction);
@@ -764,6 +772,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32, mul_mm_iq4_ks_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32, mul_mm_iq4_kss_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32, mul_mm_iq2_k_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32, mul_mm_iq2_ks_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32, mul_mm_iq3_k_f32, ctx->support_simdgroup_mm);
@@ -798,6 +807,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32, mul_mm_id_iq4_ks_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32, mul_mm_id_iq4_kss_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32, mul_mm_id_iq2_k_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32, mul_mm_id_iq2_ks_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32, mul_mm_id_iq3_k_f32, ctx->support_simdgroup_mm);
@@ -1997,6 +2007,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break;
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break;
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break;
@@ -2222,6 +2233,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32].pipeline;
} break;
+ case GGML_TYPE_IQ4_KSS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KSS_F32].pipeline;
+ } break;
case GGML_TYPE_IQ2_K:
{
nth0 = 4;
@@ -2309,7 +2326,8 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
- src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS) {
+ src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
+ src0t == GGML_TYPE_IQ4_KSS) {
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -2405,6 +2423,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break;
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break;
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break;
@@ -2618,6 +2637,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32].pipeline;
} break;
+ case GGML_TYPE_IQ4_KSS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline;
+ } break;
case GGML_TYPE_IQ2_K:
{
nth0 = 4;
@@ -2716,7 +2741,8 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
- src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS) {
+ src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
+ src0t == GGML_TYPE_IQ4_KSS) {
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -2770,6 +2796,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break;
+ case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break;
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break;
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break;
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break;
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index dff9326f..8981cda9 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -6142,6 +6142,117 @@ void kernel_mul_mv_iq4_ks_f32_impl(
}
}
+void kernel_mul_mv_iq4_kss_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values_i8,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = (r0 * 2 + sgitg) * 2;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint row_size = 4 + nb*sizeof(block_iq4_kss);
+ const uint offset0 = (i12/r2)*ne01 + (i13/r3)*(ne01*ne02);
+ device const char * cx = (device const char *)src0 + (first_row + offset0)*row_size;
+ device const float * y = (device const float *)src1 + r1*ne10 + im*ne00*ne1;
+
+ const int ix = tiisg/16; // 0 or 1
+ const int it = tiisg%16; // 0...15
+ const int ib = it/2;
+ const int il = it%2;
+
+ shared_values[tiisg] = kvalues_iq4k_f[tiisg];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float4 yl[4];
+ float2 sumf = 0.f;
+ float d[2];
+
+ device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
+
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+
+ float4 qf1, qf2;
+
+ device const float * dptr = (device const float *)cx;
+ d[0] = *dptr;
+ device const uint32_t * qptr = (device const uint32_t *)(dptr + 1) + ix*(QK_K/8) + 4*ib;
+ dptr += row_size/4;
+ d[1] = *dptr;
+
+ for (int ibl = ix; ibl < nb; ibl += 2) {
+
+ device const float4 * y4 = (device const float4 *)yb;
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+ device const uint32_t * q4 = qptr;
+
+ for (int row = 0; row < 2; ++row) {
+
+ uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
+ int16_t ls = (s32 | (s32 >> 15)) & 0xff;
+
+ threadgroup const float * block_values = shared_values + ((ls & 1) << 4);
+ const float scale = ((ls & 254) - 127);
+
+ float4 acc1 = {0.f}, acc2 = {0.f};
+
+ uint32_t v32 = q4[2*il+0] & 0xfffefffe;
+ v32 ^= (v32 >> 1);
+ aux32 = v32 & 0x0f0f0f0f;
+ qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
+ acc1 += yl[0] * qf1;
+ aux32 = (v32 >> 4) & 0x0f0f0f0f;
+ qf2 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
+ acc2 += yl[1] * qf2;
+
+ v32 = q4[2*il+1] & 0xfffefffe;
+ v32 ^= (v32 >> 1);
+ aux32 = v32 & 0x0f0f0f0f;
+ qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
+ acc1 += yl[2] * qf1;
+ aux32 = (v32 >> 4) & 0x0f0f0f0f;
+ qf2 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
+ acc2 += yl[3] * qf2;
+
+ acc1 += acc2;
+
+ sumf[row] += d[row] * scale * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+ q4 += row_size/4;
+
+ }
+
+ yb += 2 * QK_K;
+ qptr += 2 * (QK_K/8);
+ }
+
+ sumf = simd_sum(sumf);
+ if (tiisg < 2) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
+ }
+}
+
void kernel_mul_mv_iq2_k_f32_impl(
device const void * src0,
device const float * src1,
@@ -7098,6 +7209,35 @@ kernel void kernel_mul_mv_iq4_ks_f32(
kernel_mul_mv_iq4_ks_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
+[[host_name("kernel_mul_mv_iq4_kss_f32")]]
+kernel void kernel_mul_mv_iq4_kss_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq4_kss_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
[[host_name("kernel_mul_mv_iq4_k_f32")]]
kernel void kernel_mul_mv_iq4_k_f32(
device const void * src0,
@@ -7715,6 +7855,30 @@ void dequantize_iq4_ks(device const block_iq4_ks * xb, short il, thread type4x4
}
template <typename type4x4>
+void dequantize_iq4_kss(device const block_iq4_kss * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+ uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
+ uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
+ const half scale = (ls & 254) - 127;
+ constant float * values = kvalues_iq4k_f + ((ls & 1) << 4);
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+ for (int i = 0; i < 4; ++i) {
+ aux32 = q4[i] & 0xfffefffe;
+ aux32 ^= (aux32 >> 1);
+ aux32 = (aux32 >> 4*il) & 0x0f0f0f0f;
+ reg[i][0] = scale * values[q8[0]];
+ reg[i][1] = scale * values[q8[1]];
+ reg[i][2] = scale * values[q8[2]];
+ reg[i][3] = scale * values[q8[3]];
+ }
+}
+
+template <typename type4x4>
void dequantize_iq2_k(device const block_iq2_k * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256
device const uint32_t * q32 = (device const uint32_t *)xb->qs + 8*(il/8) + 4*(il&1);
@@ -8378,6 +8542,7 @@ template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>;
template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>;
template [[host_name("kernel_get_rows_iq4_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>;
+template [[host_name("kernel_get_rows_iq4_kss")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
//
@@ -8422,6 +8587,7 @@ template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>;
template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>;
template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>;
+template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
//
@@ -8463,6 +8629,7 @@ template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>;
template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>;
template [[host_name("kernel_mul_mm_id_iq4_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>;
+template [[host_name("kernel_mul_mm_id_iq4_kss_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
template [[host_name("kernel_mul_mm_id_iq2_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
//
@@ -8680,6 +8847,7 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_ks_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq4_kss_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_kss_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_k_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_ks_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq3_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_k_f32_impl>>;
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index a845eaf5..68ec6126 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15197,6 +15197,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ2_TN: break;
case GGML_TYPE_IQ1_TN: break;
case GGML_TYPE_IQ4_KS: break;
+ case GGML_TYPE_IQ4_KSS: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index a9f795ae..35ed68d0 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1100,6 +1100,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 4,
},
+ [GGML_TYPE_IQ4_KSS] = {
+ .type_name = "iq4_kss",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq4_kss),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq4_kss,
+ .from_float = quantize_row_iq4_kss,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq4_kss_ref,
+ .vec_dot = vec_dot_iq4_kss_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ .row_meta_size = 4,
+ },
[GGML_TYPE_Q8_K] = {
.type_name = "q8_K",
.blck_size = QK_K,
@@ -3918,6 +3931,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break;
+ case GGML_FTYPE_MOSTLY_IQ4_KSS: wtype = GGML_TYPE_IQ4_KSS; break;
case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break;
case GGML_FTYPE_MOSTLY_IQ2_KS: wtype = GGML_TYPE_IQ2_KS; break;
case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break;
@@ -10419,6 +10433,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
@@ -10809,6 +10824,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
@@ -10949,6 +10965,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
@@ -14135,6 +14152,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
@@ -14515,6 +14533,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
@@ -14789,6 +14808,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
@@ -15390,6 +15410,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
@@ -22208,6 +22229,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ4_KSS: result = quantize_iq4_kss(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_KS: result = quantize_iq2_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 66d26a25..7cd0dbf5 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -1209,6 +1209,67 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
};
};
+struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
+ DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ uint32_t aux32[2];
+ auto b1 = _mm512_loadu_si512((const __m512i *)x[i].qs + 0);
+ auto b2 = _mm512_loadu_si512((const __m512i *)x[i].qs + 1);
+ auto bs1 = _mm512_and_si512(b1, mask15);
+ bs1 = _mm512_xor_si512(bs1, _mm512_srli_epi16(bs1, 1));
+ auto bs2 = _mm512_and_si512(b2, mask15);
+ bs2 = _mm512_xor_si512(bs2, _mm512_srli_epi16(bs2, 1));
+ bits.values[0] = _mm512_and_si512(bs1, bits.ml);
+ bits.values[1] = _mm512_and_si512(_mm512_srli_epi16(bs1, 4), bits.ml);
+ bits.values[2] = _mm512_and_si512(bs2, bits.ml);
+ bits.values[3] = _mm512_and_si512(_mm512_srli_epi16(bs2, 4), bits.ml);
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
+ bits.values[0] = _mm512_shuffle_epi8(values, tmp);
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
+ bits.values[2] = _mm512_shuffle_epi8(values, tmp);
+ //
+ // Now the more difficult part - prepare the scales
+ //
+ aux32[0] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b1, mask1), mask1);
+ aux32[1] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b2, mask1), mask1);
+
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)aux32));
+ auto m1 = _mm512_castsi512_si128(mask1);
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accm);
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
+ scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
+ scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ const __m512i values;
+ const __m512i mask15 = _mm512_set1_epi16(0xfffe);
+ const __m512i mask1 = _mm512_set1_epi16(1);
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m128i m4 = _mm_set1_epi16(4);
+ const __m512i shuffles[4] = {
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
+ };
+};
+
+
template <typename Q8>
inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
@@ -1821,8 +1882,54 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
const __m128i m128 = _mm_set1_epi16(-128);
const __m128i m1 = _mm_set1_epi16(1);
const __m128i m4 = _mm_set1_epi16(4);
- const __m256i shuff1 = _mm256_set_epi64x(0x0706070605040504, 0x0302030201000100, 0x0706070605040504, 0x0302030201000100);
- const __m256i shuff2 = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908);
+};
+
+struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
+ DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ union { __m256i vec; uint16_t val[16]; } helper;
+ for (int k = 0; k < 4; ++k) {
+ data[k] = _mm256_loadu_si256((const __m256i *)x[i].qs + k);
+ auto p = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(data[k], m1), m1), smask);
+ p = _mm256_add_epi32(_mm256_unpackhi_epi64(p, p), p);
+ p = _mm256_add_epi32(_mm256_shuffle_epi32(p, _MM_SHUFFLE(2, 3, 0, 1)), p);
+ helper.vec = _mm256_hadd_epi16(p, p);
+ aux[2*k+0] = helper.val[0];
+ aux[2*k+1] = helper.val[8];
+ data[k] = _mm256_and_si256(data[k], bmask);
+ data[k] = _mm256_xor_si256(data[k], _mm256_srli_epi16(data[k], 1));
+ }
+ auto scales128 = _mm_loadu_si128((const __m128i *)aux);
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, _mm256_castsi256_si128(m1)), _mm256_castsi256_si128(m1)), m4);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accd);
+ return MM256_SET_M128I(scales128, scales128);
+ }
+ inline void prepare(int, int j) {
+ for (int k = 0; k < 2; ++k) {
+ auto p1 = _mm256_castsi256_si128(data[2*j+k]);
+ auto p2 = _mm256_extractf128_si256(data[2*j+k], 1);
+ bits.values[2*k+0] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p1, 4), p1), bits.ml);
+ bits.values[2*k+0] = _mm256_shuffle_epi8(values, bits.values[2*k+0]);
+ bits.values[2*k+1] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p2, 4), p2), bits.ml);
+ bits.values[2*k+1] = _mm256_shuffle_epi8(values, bits.values[2*k+1]);
+ }
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ const __m256i values;
+ __m256i data[4];
+ const __m256i smask = _mm256_set_epi64x(0x0080004000200010, 0x0008000400020001, 0x0080004000200010, 0x0008000400020001);
+ const __m256i bmask = _mm256_set1_epi16(0xfffe);
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m256i m1 = _mm256_set1_epi16(1);
+ const __m128i m4 = _mm_set1_epi16(4);
+ uint16_t aux[8];
};
struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
@@ -3848,7 +3955,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
std::is_same_v<Dequantizer, DequantizerIQ4K> ||
std::is_same_v<Dequantizer, DequantizerIQ3K> ||
std::is_same_v<Dequantizer, DequantizerIQ4XS>||
- std::is_same_v<Dequantizer, DequantizerIQ4KS>) {
+ std::is_same_v<Dequantizer, DequantizerIQ4KS>||
+ std::is_same_v<Dequantizer, DequantizerIQ4KSS>) {
m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;
@@ -4012,6 +4120,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4KS>(mm);
break;
+ case GGML_TYPE_IQ4_KSS:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ4KSS>(mm);
+ break;
case GGML_TYPE_IQ2_K:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ2K>(mm);
@@ -4945,6 +5057,63 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
const int16x8_t m127 = vdupq_n_s16(-127);
};
+struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
+
+ DequantizerIQ4KSS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ (void)q8;
+ (void)acc;
+ auto q4bits_1 = vld1q_u16_x4((const uint16_t *)x[i].qs);
+ q4bits_2 = vld1q_u16_x4((const uint16_t *)x[i].qs + 32);
+ for (int k = 0; k < 4; ++k) {
+ aux[k+0] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_1.val[k], m1), shift));
+ aux[k+4] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_2.val[k], m1), shift));
+ q4bits_1.val[k] = vandq_u16(q4bits_1.val[k], bmask);
+ q4bits_1.val[k] = veorq_u16(q4bits_1.val[k], vshrq_n_u16(q4bits_1.val[k], 1));
+ q4bits_2.val[k] = vandq_u16(q4bits_2.val[k], bmask);
+ q4bits_2.val[k] = veorq_u16(q4bits_2.val[k], vshrq_n_u16(q4bits_2.val[k], 1));
+ }
+ make_quants(q4bits_1, bits, aux);
+ auto scales16 = vld1q_s16(aux);
+ scales16 = vaddq_s16(vandq_s16(scales16, mask), m127);
+ int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
+ return scales;
+ }
+ inline void make_quants(uint16x8x4_t& q4bits, Q4bits& bits, const int16_t * aux) const {
+ bits.b1.val[0] = vqtbl1q_s8(values.val[aux[0] & 1], vandq_u8(q4bits.val[0], bits.m4b));
+ bits.b1.val[1] = vqtbl1q_s8(values.val[aux[0] & 1], vshrq_n_u8(q4bits.val[0], 4));
+ bits.b1.val[2] = vqtbl1q_s8(values.val[aux[1] & 1], vandq_u8(q4bits.val[1], bits.m4b));
+ bits.b1.val[3] = vqtbl1q_s8(values.val[aux[1] & 1], vshrq_n_u8(q4bits.val[1], 4));
+ bits.b2.val[0] = vqtbl1q_s8(values.val[aux[2] & 1], vandq_u8(q4bits.val[2], bits.m4b));
+ bits.b2.val[1] = vqtbl1q_s8(values.val[aux[2] & 1], vshrq_n_u8(q4bits.val[2], 4));
+ bits.b2.val[2] = vqtbl1q_s8(values.val[aux[3] & 1], vandq_u8(q4bits.val[3], bits.m4b));
+ bits.b2.val[3] = vqtbl1q_s8(values.val[aux[3] & 1], vshrq_n_u8(q4bits.val[3], 4));
+ }
+ inline void prepare([[maybe_unused]] int i, int j) {
+ if (j == 0) return;
+ make_quants(q4bits_2, bits, aux+4);
+ }
+ static int16x8_t load_shift() {
+ static const int16_t k_shift[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+ return vld1q_s16(k_shift);
+ }
+
+ Q4bits bits;
+ const int8x16x2_t values;
+ const uint16x8_t mask = vdupq_n_s16(254);
+ const uint16x8_t bmask = vdupq_n_u16(0xfffe);
+ const uint16x8_t m1 = vdupq_n_u16(1);
+ const int16x8_t shift = load_shift();
+ const int16x8_t m127 = vdupq_n_s16(-127);
+ uint16x8x4_t q4bits_2;
+ int16_t aux[8];
+};
+
struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
DequantizerIQ2KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
@@ -6716,6 +6885,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ4_KS:
MulMat::set_functions<DequantizerIQ4KS>(m);
break;
+ case GGML_TYPE_IQ4_KSS:
+ MulMat::set_functions<DequantizerIQ4KSS>(m);
+ break;
case GGML_TYPE_IQ2_KS:
MulMat::set_functions<DequantizerIQ2KS>(m);
break;
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 43ea588b..26bc5ecb 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -20,6 +20,25 @@
#include <array>
#include <algorithm>
#include <cstring>
+#include <mutex>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#include <intrin.h>
+#include <ammintrin.h>
+#include <nmmintrin.h>
+#include <immintrin.h>
+#include <stdlib.h>
+inline int popcount(uint8_t x) { return __popcnt(x); }
+inline int popcount(uint16_t x) { return __popcnt(x); }
+inline int popcount(uint32_t x) { return __popcnt(x); }
+inline int popcount(uint64_t x) { return _mm_popcnt_u64(x); }
+#else
+constexpr int popcount(uint8_t x) { return __builtin_popcount(x); }
+constexpr int popcount(uint16_t x) { return __builtin_popcount(x); }
+constexpr int popcount(uint32_t x) { return __builtin_popcount(x); }
+constexpr int popcount(uint64_t x) { return __builtin_popcountll(x); }
+#endif
namespace {
@@ -2811,3 +2830,432 @@ void vec_dot_iq4_ks_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
*s = sumf;
}
+namespace {
+const uint16_t * scramble_table() {
+ static std::mutex mutex;
+ static std::vector<uint16_t> table;
+ std::lock_guard<std::mutex> lock(mutex);
+ if (table.empty()) {
+ table.resize(1 << 15);
+ for (int i = 0; i < int(table.size()); ++i) {
+ uint16_t val = i;
+ int non = popcount(val);
+ if (non%2) val |= (1 << 15);
+ bool found = false;
+ for (int j = 0; j < int(table.size()); ++j) {
+ if ((j ^ (j << 1)) == val) {
+ table[i] = j; found = true; break;
+ }
+ }
+ if (!found) {
+ printf("Oops: did not find for %d %u\n", i, val);
+ exit(1);
+ }
+ }
+ }
+ return table.data();
+}
+uint16_t prune_iq4ks(uint16_t v, const int8_t * values, const float * x, const float * w, float dl) {
+ if (popcount(v)%2 == 0) return v;
+ float best_score = std::numeric_limits<float>::max();
+ uint8_t q4[4];
+ int jbest = -1;
+ uint8_t bestq = 0;
+ for (int j = 0; j < 4; ++j) {
+ uint8_t q = (v >> 4*j) & 0xf;
+ q4[j] = q;
+ auto pc = popcount(q);
+ float diff0 = dl*iq4k_values[q] - x[j];
+ if (q > 0) {
+ uint8_t qm = q - 1u;
+ int pcm = popcount(qm);
+ if (pcm == pc-1 || pcm == pc+1) {
+ float diff1 = dl*values[qm] - x[j];
+ float score = w[j]*(diff1*diff1 - diff0*diff0);
+ if (score < best_score) {
+ best_score = score; jbest = j; bestq = qm;
+ }
+ }
+ }
+ if (q < 15) {
+ uint8_t qp = q + 1u;
+ int pcp = popcount(qp);
+ if (pcp == pc-1 || pcp == pc+1) {
+ float diff1 = dl*values[qp] - x[j];
+ float score = w[j]*(diff1*diff1 - diff0*diff0);
+ if (score < best_score) {
+ best_score = score; jbest = j; bestq = qp;
+ }
+ }
+ }
+ }
+ GGML_ASSERT(jbest >= 0);
+ q4[jbest] = bestq;
+ return (q4[0] | (q4[1] << 4) | (q4[2] << 8) | (q4[3] << 12));
+}
+static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
+ float * all_scales, float * weight,
+ const int8_t * values,
+ const float * quant_weights,
+ const uint16_t * table,
+ const int ntry) {
+
+ constexpr int super_block_size = 256;
+ constexpr int block_size = 32;
+
+ float * dptr = (float *)cy;
+ *dptr = 0;
+ block_iq4_kss * y = (block_iq4_kss *)(dptr + 1);
+
+ const int8_t * shifted_values = values + 16;
+
+ uint16_t vps[block_size/2], vms[block_size/2], vs[block_size/2];
+ float xv[4], wv[4];
+
+ float amax_scale = 0;
+
+ for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) {
+ memset(&y[ibl], 0, sizeof(block_iq4_kss));
+ const float * xbl = x + ibl*super_block_size;
+ auto scales = all_scales + ibl*(super_block_size/block_size);
+ float sigma2 = 0;
+ for (int j = 0; j < super_block_size; ++j) sigma2 += xbl[j]*xbl[j];
+ sigma2 *= 2.f/super_block_size;
+ for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+ const float * xb = xbl + ib*block_size;
+ if (quant_weights) {
+ const float * qw = quant_weights + ibl*super_block_size + ib*block_size;
+ for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
+ }
+ float amax = 0, max = 0;
+ for (int j = 0; j < block_size; ++j) {
+ float ax = fabsf(xb[j]);
+ if (ax > amax) {
+ amax = ax; max = xb[j];
+ }
+ }
+ if (!amax) {
+ scales[ib] = 0;
+ continue;
+ }
+ float best = 0;
+ bool is_shifted = false;
+ float d = -max/iq4k_values[0];
+ std::memset(vs, 0, block_size);
+ for (int itry = -ntry; itry <= ntry; ++itry) {
+ float id = (itry + values[0])/max;
+ float sumqx_p = 0, sumq2_p = 0;
+ float sumqx_m = 0, sumq2_m = 0;
+ float this_d = 1/id;
+ for (int k = 0; k < block_size/4; ++k) {
+ xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
+ wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
+ uint16_t vp = 0, vm = 0;
+ for (int j = 0; j < 4; ++j) {
+ float al = id*xv[j];
+ vp |= (best_index_iq4nl(values, al) << 4*j);
+ vm |= (best_index_iq4nl(values, -al) << 4*j);
+ }
+ vp = prune_iq4ks(vp, values, xv, wv, this_d);
+ vm = prune_iq4ks(vm, values, xv, wv, this_d);
+ for (int j = 0; j < 4; ++j) {
+ float w = wv[j];
+ float q = values[(vp >> 4*j) & 0xf];
+ sumqx_p += w*q*xv[j];
+ sumq2_p += w*q*q;
+ q = values[(vm >> 4*j) & 0xf];
+ sumqx_m += w*q*xv[j];
+ sumq2_m += w*q*q;
+ }
+ vps[k] = vp;
+ vms[k] = vm;
+ }
+ bool copy_p = false, copy_m = false;
+ if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
+ d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; copy_p = true;
+ }
+ if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
+ d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; copy_m = true;
+ }
+ if (copy_m) {
+ std::memcpy(vs, vms, block_size);
+ } else if (copy_p) {
+ std::memcpy(vs, vps, block_size);
+ }
+
+ id = (itry + shifted_values[0])/max;
+ this_d = 1/id;
+ sumqx_p = sumq2_p = 0;
+ sumqx_m = sumq2_m = 0;
+ for (int k = 0; k < block_size/4; ++k) {
+ xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
+ wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
+ uint16_t vp = 0, vm = 0;
+ for (int j = 0; j < 4; ++j) {
+ float al = id*xv[j];
+ vp |= (best_index_iq4nl(shifted_values, al) << 4*j);
+ vm |= (best_index_iq4nl(shifted_values, -al) << 4*j);
+ }
+ vp = prune_iq4ks(vp, shifted_values, xv, wv, this_d);
+ vm = prune_iq4ks(vm, shifted_values, xv, wv, this_d);
+ for (int j = 0; j < 4; ++j) {
+ float w = wv[j];
+ float q = shifted_values[(vp >> 4*j) & 0xf];
+ sumqx_p += w*q*xv[j];
+ sumq2_p += w*q*q;
+ q = shifted_values[(vm >> 4*j) & 0xf];
+ sumqx_m += w*q*xv[j];
+ sumq2_m += w*q*q;
+ }
+ vps[k] = vp;
+ vms[k] = vm;
+ }
+ copy_p = copy_m = false;
+ if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
+ d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; copy_p = true;
+ }
+ if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
+ d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; copy_m = true;
+ }
+ if (copy_m) {
+ std::memcpy(vs, vms, block_size);
+ } else if (copy_p) {
+ std::memcpy(vs, vps, block_size);
+ }
+ }
+ scales[ib] = d;
+ amax_scale = std::max(amax_scale, std::abs(d));
+ }
+ }
+ float d = amax_scale/127;
+ *dptr = d;
+ if (!d) return;
+ float id = 1/d;
+ float sumqx = 0, sumq2 = 0;
+ for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) {
+ auto scales = all_scales + (super_block_size/block_size)*ibl;
+ const float * xbl = x + ibl*super_block_size;
+ float sigma2 = 0;
+ for (int j = 0; j < super_block_size; ++j) sigma2 += xbl[j]*xbl[j];
+ sigma2 *= 2.f/super_block_size;
+ for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+ const float * xb = xbl + ib*block_size;
+ if (quant_weights) {
+ const float * qw = quant_weights + ibl*super_block_size + ib*block_size;
+ for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
+ }
+ int l = nearest_int(0.5f*(id*scales[ib]+127.f));
+ l = (std::max(0, std::min(127, l)) << 1) - 127;
+ if (l) {
+ float dl = d*l;
+ float idl = 1/dl;
+ float mse_p = 0, mse_m = 0;
+ for (int k = 0; k < block_size/4; ++k) {
+ xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
+ wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
+ uint16_t vp = 0, vm = 0;
+ for (int j = 0; j < 4; ++j) {
+ float al = idl*xv[j];
+ vp |= (best_index_iq4nl( values, al) << 4*j);
+ vm |= (best_index_iq4nl(shifted_values, al) << 4*j);
+ }
+ vp = prune_iq4ks(vp, values, xv, wv, dl);
+ vm = prune_iq4ks(vm, shifted_values, xv, wv, dl);
+ for (int j = 0; j < 4; ++j) {
+ float w = wv[j];
+ float q = values[(vp >> 4*j) & 0xf];
+ mse_p += w*(xv[j] - dl*q)*(xv[j] - dl*q);
+ q = shifted_values[(vm >> 4*j) & 0xf];
+ mse_m += w*(xv[j] - dl*q)*(xv[j] - dl*q);
+ }
+ vps[k] = vp;
+ vms[k] = vm;
+ }
+ const uint16_t * v = vps;
+ const int8_t * block_values = values;
+ if (mse_m < mse_p) {
+ v = vms;
+ block_values = values + 16;
+ }
+ for (int k = 0; k < block_size/4; ++k) {
+ xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
+ wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
+ for (int j = 0; j < 4; ++j) {
+ float q = block_values[(v[k] >> 4*j) & 0xf] * l;
+ sumqx += wv[j]*q*xv[j];
+ sumq2 += wv[j]*q*q;
+ }
+ }
+ l += 127;
+ if (mse_m < mse_p) l |= 1;
+ uint16_t * q16 = (uint16_t *)y[ibl].qs + (block_size/4)*ib;
+ for (int k = 0; k < block_size/4; ++k) {
+ auto val = table[v[k] & 0x7fff];
+ q16[k] = (val << 1) | ((l >> k) & 1);
+ }
+ } else {
+ l += 127;
+ uint16_t * q16 = (uint16_t *)y[ibl].qs + (block_size/4)*ib;
+ for (int k = 0; k < block_size/4; ++k) {
+ q16[k] = ((l >> k) & 1);
+ }
+ }
+ }
+ }
+ if (sumq2 > 0) *dptr = sumqx/sumq2;
+}
+
+void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * cx, const float * x, char *cy,
+ const float * quant_weights, float * weight, float * all_scales) {
+ constexpr int kBlockSize = 32;
+ float xv[4], wv[4];
+ uint16_t vps[kBlockSize/4];
+ const float * dptr_ks = (const float *)cx;
+ const float d_ks = *dptr_ks;
+ const block_iq4_ks * iq4ks = (const block_iq4_ks *)(dptr_ks + 1);
+ float * dptr = (float *)cy;
+ *dptr = d_ks;
+ block_iq4_kss * y = (block_iq4_kss *)(dptr + 1);
+ int nblock = n_per_row/QK_K;
+ float max_abs_scale = 0;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ auto scales = all_scales + ibl*(QK_K/kBlockSize);
+ const float * xbl = x + ibl*QK_K;
+ float sigma2 = 0;
+ for (int j = 0; j < QK_K; ++j) sigma2 += xbl[j]*xbl[j];
+ sigma2 *= 2.f/QK_K;
+ const uint16_t * q4 = (const uint16_t *)iq4ks[ibl].qs;
+ for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
+ const float * xb = xbl + ib*kBlockSize;
+ if (quant_weights) {
+ const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize;
+ for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < kBlockSize; ++j) weight[j] = xb[j]*xb[j];
+ }
+ const int8_t * values = iq4k_values + ((iq4ks[ibl].scales[ib] & 1) << 4);
+ float dl = d_ks * ((iq4ks[ibl].scales[ib] & 254) - 127);
+ float sumqx = 0, sumq2 = 0;
+ for (int k = 0; k < kBlockSize/4; ++k) {
+ xv[0] = xb[2*k+0]; xv[1] = xb[2*k+kBlockSize/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+kBlockSize/2];
+ wv[0] = weight[2*k+0]; wv[1] = weight[2*k+kBlockSize/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+kBlockSize/2];
+ auto vp = prune_iq4ks(q4[k], values, xv, wv, dl);
+ vps[k] = table[vp & 0x7fff];
+ for (int j = 0; j < 4; ++j) {
+ float q = values[(vp >> 4*j) & 0xf];
+ sumqx += wv[j]*q*xv[j];
+ sumq2 += wv[j]*q*q;
+ }
+ }
+ for (int k = 0; k < kBlockSize/8; ++k) {
+ y[ibl].qs[(kBlockSize/8)*ib + k] = vps[2*k+0] | (vps[2*k+1] << 15) | (((iq4ks[ibl].scales[ib] >> 2*k) & 3) << 30);
+ //y[ibl].qs[(kBlockSize/8)*ib + k] = vps[2*k+0] | (vps[2*k+1] << 15);
+ }
+ scales[ib] = sumq2 > 0 ? sumqx/sumq2 : dl;
+ max_abs_scale = std::max(max_abs_scale, scales[ib]);
+ q4 += kBlockSize/4;
+ }
+ }
+ //if (!max_abs_scale) return;
+ //float d = max_abs_scale/127;
+ //*dptr = d;
+ //float id = 1/d;
+ //for (int ibl = 0; ibl < nblock; ++ibl) {
+ // auto scales = all_scales + ibl*(QK_K/kBlockSize);
+ // for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
+ // int l = nearest_int(0.5f*(id*scales[ib]+127.f));
+ // l = std::max(0, std::min(127, l)) << 1;
+ // l |= (iq4ks[ibl].scales[ib] & 1);
+ // for (int k = 0; k < 4; ++k) {
+ // //y[ibl].qs[4*ib+k] &= 0x3fffffff;
+ // y[ibl].qs[4*ib+k] |= (((l >> 2*k) & 3) << 30);
+ // }
+ // }
+ //}
+}
+}
+
+size_t quantize_iq4_kss(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ constexpr int kBlockSize = 32; //128;
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ auto row_size = ggml_row_size(GGML_TYPE_IQ4_KSS, n_per_row);
+ auto row_size_ks = ggml_row_size(GGML_TYPE_IQ4_KS, n_per_row);
+ std::vector<char> work(row_size_ks);
+ std::vector<float> all_scales(n_per_row/kBlockSize);
+ float weight[kBlockSize];
+ auto qrow = (char *)dst;
+ auto table = scramble_table();
+ for (int row = 0; row < nrows; ++row) {
+ quantize_row_iq4_kss_impl(n_per_row, src, qrow, all_scales.data(), weight, iq4k_values, imatrix, table, 7);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrows * row_size;
+}
+
+void quantize_row_iq4_kss_ref(const float * x, block_iq4_kss * y, int64_t k) {
+ quantize_iq4_kss(x, y, 1, k, nullptr);
+}
+
+void quantize_row_iq4_kss(const float * x, void * y, int64_t k) {
+ quantize_iq4_kss(x, (block_iq4_kss *)y, 1, k, nullptr);
+}
+
+void dequantize_row_iq4_kss(const block_iq4_kss * x, float * y, int64_t k) {
+ const float * dptr = (const float *)x;
+ const float d = *dptr;
+ x = (const block_iq4_kss *)(dptr + 1);
+ uint16_t aux16[8];
+ const uint8_t * aux8 = (const uint8_t *)aux16;
+ for (int ibl = 0; ibl < k/QK_K; ++ibl) {
+ auto qs = (const uint16_t *)x[ibl].qs;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ //uint8_t ls = ((qs[0] >> 30) | ((qs[1] >> 28) & 0x0c) | ((qs[2] >> 26) & 0x30) | ((qs[3] >> 24) & 0xc0));
+ //const int8_t * values = iq4k_values + ((ls & 1) << 4);
+ //const float dl = d * ((ls & 254) - 127);
+ //for (int k = 0; k < 4; ++k) {
+ // uint16_t vl = qs[k] & 0x7fff;
+ // vl ^= (vl << 1);
+ // uint16_t vh = (qs[k] >> 15) & 0x7fff;
+ // vh ^= (vh << 1);
+ // for (int j = 0; j < 4; ++j) {
+ // y[4*k + j + 0] = dl*values[(vl >> 4*j) & 0xf];
+ // y[4*k + j + 16] = dl*values[(vh >> 4*j) & 0xf];
+ // }
+ //}
+ int16_t ls = 0;
+ for (int k = 0; k < 8; ++k) {
+ aux16[k] = qs[k] & 0xfffe;
+ aux16[k] ^= (aux16[k] >> 1);
+ ls |= (qs[k] & 1) << k;
+ }
+ const int8_t * values = iq4k_values + ((ls & 1) << 4);
+ float dl = d * ((ls & 254) - 127);
+ for (int j = 0; j < 16; ++j) {
+ y[j+ 0] = dl * values[aux8[j] & 0xf];
+ y[j+16] = dl * values[aux8[j] >> 4];
+ }
+ y += 32;
+ qs += 8;
+ }
+ }
+}
+
+void vec_dot_iq4_kss_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_KSS, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
+
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index eb562779..e0dde0d8 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -61,6 +61,12 @@ size_t quantize_iq4_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst
void dequantize_row_iq4_ks(const block_iq4_ks * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq4_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_iq4_kss_ref(const float * GGML_RESTRICT x, block_iq4_kss * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_kss(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq4_kss(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq4_kss(const block_iq4_kss * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq4_kss_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void quantize_row_iq2_ks_ref(const float * GGML_RESTRICT x, block_iq2_ks * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_iq2_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
diff --git a/include/llama.h b/include/llama.h
index c9387e6b..133c2f0e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -180,6 +180,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_KS = 145, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ3_KL = 146, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
diff --git a/src/llama.cpp b/src/llama.cpp
index b356f7bc..d9eec461 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3795,6 +3795,7 @@ struct llama_model_loader {
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break;
+ case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break;
case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break;
case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break;
case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break;
@@ -4498,6 +4499,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw";
+ case LLAMA_FTYPE_MOSTLY_IQ4_KSS: return "IQ4_KSS - 4.0 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_K: return "IQ3_K - 3.4325 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_KL: return "IQ3_KL - 4 bpw";
@@ -15651,7 +15653,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS) {
new_type = !qs.has_output ? GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K;
}
- else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS) && !qs.has_output) {
+ else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS ||
+ ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && !qs.has_output) {
new_type = GGML_TYPE_IQ5_K;
}
else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_IQ6_K) {
@@ -15742,7 +15745,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
- else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS) && qs.model.hparams.n_gqa() >= 2) {
+ else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS ||
+ ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && qs.model.hparams.n_gqa() >= 2) {
new_type = GGML_TYPE_IQ5_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_K && qs.model.hparams.n_gqa() >= 2) {
@@ -15822,7 +15826,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
}
}
- else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS) && !qs.has_imatrix) {
+ else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS ||
+ ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && !qs.has_imatrix) {
new_type = GGML_TYPE_Q5_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
@@ -15910,7 +15915,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K ||
new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_IQ2_TN ||
new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ1_TN || new_type == GGML_TYPE_IQ4_KS ||
- new_type == GGML_TYPE_IQ2_KS) {
+ new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS) {
int nx = tensor->ne[0];
int ny = tensor->ne[1];
if (nx % QK_K != 0) {
@@ -15942,6 +15947,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_Q3_K:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_K:
+ case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
case GGML_TYPE_IQ4_K:
@@ -16055,6 +16061,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break;
+ case LLAMA_FTYPE_MOSTLY_IQ4_KSS: default_type = GGML_TYPE_IQ4_KSS; break;
case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break;
case LLAMA_FTYPE_MOSTLY_IQ3_K: default_type = GGML_TYPE_IQ3_K; break;
case LLAMA_FTYPE_MOSTLY_IQ3_KL: default_type = GGML_TYPE_IQ3_K; break;