author     Kawrakow <iwankawrakow@gmail.com>   2024-12-06 12:15:39 +0100
committer  GitHub <noreply@github.com>         2024-12-06 12:15:39 +0100
commit     3682e4700db6b8cb2ca8e3da365578078f21ab0c (patch)
tree       ea1680494ca00580b0a038cdef035c596e80e58c
parent     f64de08203aaee95ca755336de3e1db85d990198 (diff)
iq2_bn_r4: fastest Bitnet CPU implementation on the planet (#124)
* Adding iq2_bn_r4. This Zen4-only implementation achieves PP-512 = 826 t/s (!!!) for Bitnet-1.58b-3B, up from 620 t/s for iq2_bn.
* Make sure rows per thread are a multiple of the number of interleaved rows. With this I can run iq2_bn_r4 with 32 threads, which increases PP-512 to 872 t/s.
* iq2_bn_r4: 1st shot at NEON. PP-512 is already faster than iq2_bn (284 t/s vs 246 t/s for Bitnet-1.58b-3B). TG-128 is ~5% slower.
* iq2_bn_r4: NEON. PP-512 is now 296 t/s. TG-128 is ~20% faster than iq2_bn for 1 thread, but saturates to about the same 93 t/s at 8 threads.
* iq2_bn_r4: Experimenting on NEON. The matrix x vector multiplication is erratic: iq2_bn_r4 is faster at 1, 2, and 4 threads, but saturates to a lower t/s at 8 threads compared to iq2_bn. iq2_bn actually manages 99 t/s at 8 threads, not the 93 I wrote in the last commit. iq2_bn_r4 performance fluctuates heavily at 4 and 8 threads.
* Some cleanup.
* iq2_bn_r4: AVX2. As expected, PP is slightly slower as we just don't have enough vector registers (690 vs 710 t/s). TG is slightly faster (18.2 vs 16.7 t/s at 1 thread).
* iq2_bn_r4: use AVX2 implementation on Zen4 for matrix x vector. It is faster - we get 29.6 t/s at 1 thread vs 25.9 t/s for iq2_bn.
* iq2_bn_r4: simdify q8_K16 quantization (AVX2). PP-512 becomes 834 t/s and TG-128 now saturates to the same performance as iq2_bn at 4 threads.
* iq2_bn_r4: simdify q8_K16 quantization (NEON). PP-512 is now 304.7 t/s, and TG-128 @ 8 threads very slightly outperforms iq2_bn (100.7 t/s vs 99.6 t/s).
* iq2_bn_r4: fix AVX2 after breaking it two commits ago.
* iq2_bn_r4: better AVX2. As we don't have enough vector registers on AVX2, it is better to do two passes per row, needing only half of the accumulator registers that way. With this, we now beat iq2_bn PP also on AVX2 by a small margin.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  examples/quantize/quantize.cpp     1
-rw-r--r--  ggml/include/ggml.h                3
-rw-r--r--  ggml/src/ggml-quants.c             1
-rw-r--r--  ggml/src/ggml-quants.h             7
-rw-r--r--  ggml/src/ggml.c                   39
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp     421
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp    272
-rw-r--r--  ggml/src/iqk/iqk_quantize.h       16
-rw-r--r--  include/llama.h                    1
-rw-r--r--  src/llama.cpp                     11
10 files changed, 738 insertions, 34 deletions
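
For orientation before the diff: iq2_bn_r4 packs Bitnet weights four rows at a time. Each group of four rows stores four float scales followed, for every 64-column block, by 64 bytes whose four 2-bit fields interleave the rows (the exact mapping is in repack_iq2_bn and dequantize_row_iq2_bn_r4 further down). Below is a minimal scalar sketch of a dot product over this layout; the function and parameter names are illustrative and not part of the patch.

    // Scalar sketch of an iq2_bn_r4 dot product for one group of 4 interleaved rows.
    // Layout follows repack_iq2_bn / dequantize_row_iq2_bn_r4 in this commit.
    #include <cstdint>

    static void iq2_bn_r4_dot_sketch(int n,                      // columns, multiple of 64
                                     const float   * scales,     // 4 scales, one per interleaved row
                                     const uint8_t * packed,     // 64 bytes per 64-column block
                                     const float   * x,          // one activation row of length n
                                     float         * out) {      // 4 results, one per row
        for (int k = 0; k < 4; ++k) out[k] = 0.f;
        for (int ib = 0; ib < n/64; ++ib) {
            const uint8_t * q = packed + 64*ib;
            for (int k = 0; k < 4; ++k) {            // row within the group of 4
                float sum = 0.f;
                for (int l = 0; l < 4; ++l) {        // 16-column slice within the block
                    for (int i = 0; i < 4; ++i) {
                        uint8_t b = q[4*k + i + 16*l];
                        // each byte holds four ternary values, stored as 0..2 meaning -1..+1
                        sum += x[64*ib + 16*l + i +  0] * (((b >> 0) & 3) - 1);
                        sum += x[64*ib + 16*l + i +  4] * (((b >> 2) & 3) - 1);
                        sum += x[64*ib + 16*l + i +  8] * (((b >> 4) & 3) - 1);
                        sum += x[64*ib + 16*l + i + 12] * (((b >> 6) & 3) - 1);
                    }
                }
                out[k] += scales[k] * sum;
            }
        }
    }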
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index f8ce3edd..89fb8464 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -29,6 +29,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", },
{ "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", },
+ { "IQ2_BN_R4",LLAMA_FTYPE_MOSTLY_IQ2_BN_R4," 2.00 bpw quantization (Bitnet)", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 09f92eb9..962e9688 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -406,6 +406,7 @@ extern "C" {
GGML_TYPE_IQ4_KS = 144,
GGML_TYPE_IQ2_KS = 145,
GGML_TYPE_IQ4_KSS = 146,
+ GGML_TYPE_Q8_K16 = 147,
GGML_TYPE_Q4_0_R4 = 202,
GGML_TYPE_Q5_0_R4 = 206,
@@ -413,6 +414,7 @@ extern "C" {
GGML_TYPE_IQ4_NL_X4 = 220, // TODO: rename GGML_TYPE_IQ4_NL_X4 to GGML_TYPE_IQ4_NL_R4
GGML_TYPE_IQ4_XS_R4 = 223,
GGML_TYPE_Q6_0_R4 = 233,
+ GGML_TYPE_IQ2_BN_R4 = 335,
GGML_TYPE_COUNT,
};
@@ -478,6 +480,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_NL_X4 = 219, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_XS_R4 = 222, // except 1d tensors
GGML_FTYPE_MOSTLY_Q6_0_R4 = 227, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_BN_R4 = 329, // except 1d tensors
};
// available tensor operations:
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 4fdd2c36..a9511f00 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15218,6 +15218,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_I64:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_BN_R4:
// nothing to validate
break;
default:
diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h
index a40a6d37..b6d69011 100644
--- a/ggml/src/ggml-quants.h
+++ b/ggml/src/ggml-quants.h
@@ -33,7 +33,6 @@ void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_REST
void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_xxs_ref(const float * GGML_RESTRICT x, block_iq2_xxs * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_xs_ref (const float * GGML_RESTRICT x, block_iq2_xs * GGML_RESTRICT y, int64_t k);
@@ -43,7 +42,6 @@ void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGM
void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
void quantize_row_iq1_bn_ref (const float * GGML_RESTRICT x, block_iq1_bn * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq2_bn_ref (const float * GGML_RESTRICT x, block_iq2_bn * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -59,7 +57,6 @@ void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -69,7 +66,6 @@ void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq1_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq2_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
// Dequantization
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -97,7 +93,6 @@ void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_
void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq1_bn (const block_iq1_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
-void dequantize_row_iq2_bn (const block_iq2_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -123,7 +118,6 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq1_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_iq2_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
@@ -136,7 +130,6 @@ size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq1_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
-size_t quantize_iq2_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f4320e99..12afec05 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1026,11 +1026,24 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.to_float = (ggml_to_float_t) dequantize_row_iq2_bn,
.from_float = quantize_row_iq2_bn,
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_bn_ref,
- .vec_dot = ggml_vec_dot_iq2_bn_q8_K64,
+ .vec_dot = vec_dot_iq2_bn_q8_K64,
.vec_dot_type = GGML_TYPE_Q8_K64,
.nrows = 1,
.row_meta_size = 4,
},
+ [GGML_TYPE_IQ2_BN_R4] = {
+ .type_name = "iq2_bn_r4",
+ .blck_size = QK_IQ1BN,
+ .type_size = sizeof(block_iq2_bn),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_bn_r4,
+ .from_float = quantize_row_iq2_bn_r4,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq2_bn_r4_ref,
+ .vec_dot = vec_dot_iq2_bn_r4_q8_K64,
+ .vec_dot_type = GGML_TYPE_Q8_K16,
+ .nrows = 1,
+ .row_meta_size = 4,
+ },
[GGML_TYPE_IQ4_NL] = {
.type_name = "iq4_nl",
.blck_size = QK4_NL,
@@ -1103,6 +1116,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q8_K64,
.row_meta_size = 0,
},
+ [GGML_TYPE_Q8_K16] = {
+ .type_name = "q8_K16",
+ .blck_size = 64,
+ .type_size = 64,
+ .is_quantized = true,
+ .from_float = quantize_row_q8_K16,
+ .row_meta_size = 20,
+ },
[GGML_TYPE_BF16] = {
.type_name = "bf16",
.blck_size = 1,
@@ -4000,6 +4021,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break;
case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break;
case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break;
+ case GGML_FTYPE_MOSTLY_IQ2_BN_R4: wtype = GGML_TYPE_IQ2_BN_R4;break;
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
case GGML_FTYPE_MOSTLY_IQ4_NL_X4: wtype = GGML_TYPE_IQ4_NL_X4;break;
case GGML_FTYPE_MOSTLY_IQ4_XS_R4: wtype = GGML_TYPE_IQ4_XS_R4;break;
@@ -10529,6 +10551,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS_R4:
@@ -10977,6 +11000,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS_R4:
@@ -11122,6 +11146,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS_R4:
@@ -14313,6 +14338,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS_R4:
@@ -14698,6 +14724,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS_R4:
@@ -14977,6 +15004,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS_R4:
@@ -15583,6 +15611,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS_R4:
@@ -15603,6 +15632,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q8_K:
case GGML_TYPE_Q8_K64:
+ case GGML_TYPE_Q8_K16:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
@@ -20476,7 +20506,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
if (node->src[1]->type != vec_dot_type) {
- cur_q = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
+ cur_q = ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]);
+ //cur_q = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
}
} break;
case GGML_OP_MUL_MAT_ID:
@@ -20486,7 +20517,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
const struct ggml_tensor * src1 = node->src[1];
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) {
- cur_q += ggml_row_size(vec_dot_type, ggml_nelements(src1));
+ cur_q += ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]);
+ //cur_q += ggml_row_size(vec_dot_type, ggml_nelements(src1));
}
const int n_as = src0->ne[2];
cur_q += GGML_PAD(cur, sizeof(int64_t)); // align
@@ -22415,6 +22447,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_BN_R4:result = quantize_iq2_bn_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL_X4: result = quantize_iq4_nl_x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS_R4: result = quantize_iq4_xs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
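
The two ggml_graph_plan changes above size the quantized work buffer per row rather than over the flattened element count, presumably because the new Q8_K16 type carries a 20-byte per-row header (row_meta_size), so the buffer size is no longer proportional to the number of elements. An illustrative sketch with assumed shapes, not taken from the patch:

    #include "ggml.h"  // assumed include path

    // src1 assumed to have ne[0] = 4096 columns and 512 rows, vec_dot_type = GGML_TYPE_Q8_K16.
    static void work_buffer_sizing_sketch(void) {
        size_t old_size = ggml_row_size(GGML_TYPE_Q8_K16, 4096LL * 512); // treats src1 as one huge row
        size_t new_size = ggml_row_size(GGML_TYPE_Q8_K16, 4096) * 512;   // one 20-byte header per actual row
        (void)old_size; (void)new_size;
    }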
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index faa4cab7..b6ff7ab7 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -161,6 +161,17 @@ struct MulMat {
}
}
static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
+ static inline int num_rows(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0_R4:
+ case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
+ case GGML_TYPE_Q8_0_R4:
+ case GGML_TYPE_IQ4_NL_X4:
+ case GGML_TYPE_IQ2_BN_R4: return 4;
+ default: return 1;
+ }
+ }
private:
template <typename Dequantizer> static void set_functions(MulMat& m);
};
@@ -181,13 +192,15 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00,
size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB));
//if (ith == 0) printf("%s: ne00 = %d, row_size_qx = %d, strideA = %d\n", __func__, int(ne00), int(row_size_qx), int(strideA));
- auto nrc_x = (Nx + nth - 1)/nth;
+ auto num_rows = MulMat::num_rows(ggml_type(typeA));
+ GGML_ASSERT(Nx%num_rows == 0);
+ auto nrc_x = (Nx/num_rows + nth - 1)/nth;
auto first_x = ith*nrc_x;
- if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
+ if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
- DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
+ DataInfo info{C + first_x*num_rows, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
- mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
+ mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x*num_rows, row_size_qx, info, nrc_x*num_rows, Ny);
return true;
}
@@ -319,6 +332,30 @@ template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
const block_q8 * y[nrc_y];
};
+template <int nrc> struct Q8_16 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8_16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto ptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+ y[iy] = (const int8_t *)(ptr + 5);
+ }
+ }
+
+#ifdef HAVE_FANCY_SIMD
+ inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); }
+#endif
+ inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); }
+ inline float scale(int iy, int k) const { return d[5*iy+k]; }
+ inline float sum_row(int iy) const { return d[5*iy + 4]; }
+ inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); }
+
+ float d[5*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
struct Scales8KBase {
template <typename Q8>
inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
@@ -2079,6 +2116,228 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
#endif // Zen4 or vanilla AVX2
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ Q8_16<nrc_y> q8(info);
+ auto m3 = _mm256_set1_epi8(0x3);
+ auto m1 = _mm256_set1_epi16(1);
+ int nb = n / QK_IQ1BN;
+ __m256i qx[4];
+ if constexpr (nrc_y > 4) {
+ __m256i acc[nrc_y] = {};
+ __m128 sum4[nrc_y];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = _mm_loadu_ps(dptr);
+ const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+0);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
+ auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+ s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
+ sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
+ acc[iy] = _mm256_setzero_si256();
+ }
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+1);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
+ auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
+ s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
+ info.store(ix, iy, s4);
+ acc[iy] = _mm256_setzero_si256();
+ }
+ }
+ } else {
+ __m256i acc[2*nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = _mm_loadu_ps(dptr);
+ const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+0);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+1);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
+ auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
+ auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+ sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+ sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+ sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+ sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+ info.store(ix, iy, sum4);
+ acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
+ }
+ }
+ }
+}
+
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ if constexpr (nrc_y == 1) {
+ mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x);
+ } else {
+ Q8_16<nrc_y> q8(info);
+ auto m3 = _mm512_set1_epi8(0x3);
+ int nb = n / QK_IQ1BN;
+ __m512i acc[2*nrc_y] = {};
+ __m512i qx[8];
+ for (int ix = 0; ix < nrc_x/8; ++ix) {
+ const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx);
+ const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx);
+ auto dl = _mm_loadu_ps(dptr1);
+ auto dh = _mm_loadu_ps(dptr2);
+ const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4);
+ const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
+ auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib);
+ qx[0] = _mm512_and_si512(bits_l, m3);
+ qx[1] = _mm512_and_si512(bits_h, m3);
+ qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
+ qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3);
+ qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
+ qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3);
+ qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
+ qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants64(iy, ib);
+ auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy);
+ sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy);
+ sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy);
+ sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ __m128 sum4;
+ for (int k = 0; k < 2; ++k) {
+ const auto& dx = k == 0 ? dl : dh;
+ auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]);
+ sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00)));
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+ sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+ info.store(8*ix + 4*k, iy, sum4);
+ }
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
+ }
+ }
+ if (int ix = 8*(nrc_x/8); ix < nrc_x) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = _mm_loadu_ps(dptr);
+ const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
+ qx[0] = _mm512_and_si512(bits_l, m3);
+ qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
+ qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
+ qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants64(iy, ib);
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf = _mm512_cvtepi32_ps(acc[iy]);
+ auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+ sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+ info.store(ix, iy, sum4);
+ }
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ mul_mat_iq2_bn_r4_q8_k16_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -4744,6 +5003,20 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
expected_typeB = GGML_TYPE_Q8_K64;
break;
+ case GGML_TYPE_IQ2_BN_R4:
+ assert (ne00 % QK_IQ1BN == 0);
+ mm.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
+ mm.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
+ mm.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
+ mm.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
+ mm.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
+ mm.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
+//#ifdef HAVE_FANCY_SIMD
+ mm.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
+ mm.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
+//#endif
+ expected_typeB = GGML_TYPE_Q8_K16;
+ break;
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
MulMat::set_functions<Q4_0_1_Unpacker>(mm);
@@ -7171,6 +7444,135 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
}
+template <int nrc> struct Q8_16 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8_16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto ptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+ y[iy] = (const int8_t *)(ptr + 5);
+ }
+ }
+
+ inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy] + 64*i); }
+ inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); }
+ inline float scale(int iy, int k) const { return d[5*iy+k]; }
+ inline float sum_row(int iy) const { return d[5*iy + 4]; }
+ inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); }
+
+ float d[5*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
+template <int nrc_y>
+static IQK_NOINLINE void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ Q8_16<nrc_y> q8(info);
+ auto m3 = vdupq_n_u8(0x3);
+ int nb = n / QK_IQ1BN;
+ if constexpr (nrc_y == 1) {
+ auto mc = vdupq_n_u8(0xc);
+ int32x4_t acc[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0);
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = vld1q_f32(dptr);
+ const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto y = q8.load_quants(0, ib);
+ for (int j = 0; j < 4; ++j) {
+ auto bits1 = vld1q_u8(iq2 + 64*ib + 16*j);
+ auto bits2 = vshrq_n_u8(bits1, 4);
+ acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits1, m3), y.val[j], 0);
+ acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits1, mc), y.val[j], 1);
+ acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits2, m3), y.val[j], 2);
+ acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits2, mc), y.val[j], 3);
+ }
+ }
+ auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0)));
+ auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy);
+ auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy);
+ dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1)));
+ sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy);
+ sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy);
+ dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2)));
+ sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy);
+ sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), dy);
+ dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3)));
+ sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy);
+ sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy);
+ auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2);
+ sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
+ info.store(ix, 0, sumf);
+ }
+ } else {
+ int32x4_t acc[4*nrc_y] = {};
+ uint8x16_t qx[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = vld1q_f32(dptr);
+ const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = vld1q_u8_x2(iq2 + 64*ib);
+ qx[0] = vandq_u8(bits.val[0], m3);
+ qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+ qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+ qx[3] = vshrq_n_u8(bits.val[0], 6);
+ qx[4] = vandq_u8(bits.val[1], m3);
+ qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+ qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+ qx[7] = vshrq_n_u8(bits.val[1], 6);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants_32(iy, 2*ib+0);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
+ }
+ bits = vld1q_u8_x2(iq2 + 64*ib + 32);
+ qx[0] = vandq_u8(bits.val[0], m3);
+ qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+ qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+ qx[3] = vshrq_n_u8(bits.val[0], 6);
+ qx[4] = vandq_u8(bits.val[1], m3);
+ qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+ qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+ qx[7] = vshrq_n_u8(bits.val[1], 6);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants_32(iy, 2*ib+1);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
+ sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
+ sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
+ sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
+ sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
+ info.store(ix, iy, sumf);
+ acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0);
+ }
+ }
+ }
+}
+
template <int nrc_y>
static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
@@ -7716,6 +8118,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
expected_Btype = GGML_TYPE_Q8_K64;
break;
+ case GGML_TYPE_IQ2_BN_R4:
+ m.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
+ m.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
+ m.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
+ m.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
+ m.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
+ //m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
+ //m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
+ //m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
+ expected_Btype = GGML_TYPE_Q8_K16;
+ break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
expected_Btype = GGML_TYPE_Q8_0;
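
A note on the kernels added in this file: the packed 2-bit value q is the ternary weight shifted up by one (w = q - 1), so both the AVX2/AVX512 path (_mm256_maddubs_epi16, _mm512_dpbusd_epi32) and the NEON path (vdotq_laneq_s32) can use unsigned-by-signed dot products and fold the -1 offset back in with the row sum that Q8_K16 stores in its header (the fused multiply-adds with -q8.sum_row at the end of each kernel). A scalar sketch of the identity, with assumed names, not part of the patch:

    // d * sum_j (q[j] - 1) * x[j]  ==  d * sum_j q[j]*x[j]  -  d * sum_j x[j]
    // The first term is what the unsigned SIMD dot products accumulate (with x replaced
    // by its Q8_K16 int8 quantization and block scales); the second is the per-row
    // d * sum_row correction.
    #include <cstdint>

    static float iq2_bn_offset_identity(float d, const uint8_t * q, const float * x, int n) {
        float qdot = 0.f, xsum = 0.f;
        for (int j = 0; j < n; ++j) { qdot += q[j] * x[j]; xsum += x[j]; }
        return d * qdot - d * xsum;   // equals d * sum_j (q[j] - 1) * x[j]
    }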
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index acef04db..32fe92ef 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -362,7 +362,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
*s = d8[0] * (sumi[0] + sumi[1]) + d8[1] * (sumi[2] + sumi[3]) + d8[2] * (sumi[4] + sumi[5]) + d8[3] * (sumi[6] + sumi[7]);
}
-void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+void vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
GGML_ASSERT(nrc == 1);
GGML_UNUSED(bs);
@@ -520,6 +520,136 @@ void quantize_row_q8_K64(const float * x, void * y, int64_t k) {
quantize_row_q8_K64_ref(x, (block_q8_K64 *)y, k);
}
+#ifdef __AVX2__
+namespace {
+inline float hsum_float_4(__m128 x) {
+ x = _mm_add_ps(x, _mm_movehl_ps(x, x));
+ x = _mm_add_ss(x, _mm_movehdup_ps(x));
+ return _mm_cvtss_f32(x);
+}
+inline float hsum_float_8(__m256 x) {
+ return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
+}
+inline int hsum_i32_8(const __m256i a) {
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+inline float hmax_f32_8(__m256 x) {
+ __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
+ max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
+ return _mm_cvtss_f32(max4);
+}
+}
+#endif
+
+void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) {
+ float * dptr = (float *)vy;
+ int8_t * qy = (int8_t *)(dptr + 5);
+ int n64 = nk / 64;
+#ifdef __AVX2__
+ __m256 sign_bit = _mm256_set1_ps(-0.f);
+ __m256 vmax[4] = {};
+ __m256 vsum[4] = {};
+ for (int i64 = 0; i64 < n64; ++i64) {
+ for (int k = 0; k < 4; ++k) {
+ auto v1 = _mm256_loadu_ps(x + 64*i64 + 16*k + 0);
+ auto v2 = _mm256_loadu_ps(x + 64*i64 + 16*k + 8);
+ vsum[k] = _mm256_add_ps(vsum[k], _mm256_add_ps(v1, v2));
+ v1 = _mm256_andnot_ps(sign_bit, v1);
+ v2 = _mm256_andnot_ps(sign_bit, v2);
+ vmax[k] = _mm256_max_ps(vmax[k], _mm256_max_ps(v1, v2));
+ }
+ }
+ __m256 sum = _mm256_add_ps(_mm256_add_ps(vsum[0], vsum[1]), _mm256_add_ps(vsum[2], vsum[3]));
+ dptr[4] = hsum_float_8(sum);
+ for (int k = 0; k < 4; ++k) {
+ float max = hmax_f32_8(vmax[k]);
+ dptr[k] = max/127;
+ vmax[k] = _mm256_set1_ps(dptr[k] > 0 ? 1/dptr[k] : 0.f);
+ }
+ __m256i ival[8];
+ const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+ for (int i64 = 0; i64 < n64; ++i64) {
+ for (int k = 0; k < 4; ++k) {
+ __m256 v0 = _mm256_mul_ps(vmax[k], _mm256_loadu_ps(x + 64*i64 + 16*k + 0));
+ __m256 v1 = _mm256_mul_ps(vmax[k], _mm256_loadu_ps(x + 64*i64 + 16*k + 8));
+ v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
+ v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
+ ival[2*k+0] = _mm256_cvtps_epi32(v0);
+ ival[2*k+1] = _mm256_cvtps_epi32(v1);
+ }
+ for (int k = 0; k < 2; ++k) {
+ auto i0 = _mm256_packs_epi32(ival[4*k+0], ival[4*k+1]);
+ auto i1 = _mm256_packs_epi32(ival[4*k+2], ival[4*k+3]);
+ i0 = _mm256_packs_epi16(i0, i1);
+ i0 = _mm256_permutevar8x32_epi32(i0, perm);
+ _mm256_storeu_si256((__m256i *)qy, i0);
+ qy += 32;
+ }
+ }
+#elif defined __ARM_NEON
+ static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
+ auto shuffle = vld1q_u8(k_shuffle);
+ float32x4_t vmax[4] = {};
+ float32x4_t vsum[4] = {};
+ for (int i64 = 0; i64 < n64; ++i64) {
+ for (int k = 0; k < 4; ++k) {
+ auto v = vld1q_f32_x4(x + 64*i64 + 16*k);
+ vsum[k] = vaddq_f32(vsum[k], vaddq_f32(v.val[0], v.val[1]));
+ vsum[k] = vaddq_f32(vsum[k], vaddq_f32(v.val[2], v.val[3]));
+ vmax[k] = vmaxq_f32(vmax[k], vmaxq_f32(vabsq_f32(v.val[0]), vabsq_f32(v.val[1])));
+ vmax[k] = vmaxq_f32(vmax[k], vmaxq_f32(vabsq_f32(v.val[2]), vabsq_f32(v.val[3])));
+ }
+ }
+ dptr[4] = vaddvq_f32(vaddq_f32(vaddq_f32(vsum[0], vsum[1]), vaddq_f32(vsum[2], vsum[3])));
+ for (int k = 0; k < 4; ++k) {
+ float max = vmaxvq_f32(vmax[k]);
+ dptr[k] = max/127;
+ vmax[k] = vdupq_n_f32(dptr[k] > 0 ? 1/dptr[k] : 0.f);
+ }
+ int8x16x4_t q;
+ for (int i64 = 0; i64 < n64; ++i64) {
+ for (int k = 0; k < 4; ++k) {
+ auto v = vld1q_f32_x4(x + 64*i64 + 16*k);
+ for (int j = 0; j < 4; ++j) {
+ q.val[j] = vreinterpretq_s8_s32(vcvtnq_s32_f32(vmulq_f32(vmax[k], v.val[j])));
+ }
+ auto qi = vqtbl4q_s8(q, shuffle);
+ vst1q_s8(qy, qi);
+ qy += 16;
+ }
+ }
+#else
+ float amax[4] = {0.f, 0.f, 0.f, 0.f};
+ for (int i64 = 0; i64 < n64; ++i64) {
+ for (int k = 0; k < 4; ++k) {
+ for (int j = 0; j < 16; ++j) {
+ float ax = std::abs(x[64*i64 + 16*k + j]);
+ amax[k] = std::max(amax[k], ax);
+ }
+ }
+ }
+ for (int k = 0; k < 4; ++k) {
+ dptr[k] = amax[k]/127;
+ amax[k] = dptr[k] > 0 ? 1/dptr[k] : 0.f;
+ }
+ double sumf = 0;
+ for (int i64 = 0; i64 < n64; ++i64) {
+ for (int k = 0; k < 4; ++k) {
+ for (int j = 0; j < 16; ++j) {
+ sumf += x[64*i64 + 16*k + j];
+ qy[64*i64 + 16*k + j] = nearest_int(amax[k]*x[64*i64 + 16*k + j]);
+ }
+ }
+ }
+ dptr[4] = sumf;
+#endif
+}
+
//
// ============================================== iq2_K
//
@@ -2339,23 +2469,6 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
return nrows * nblock * sizeof(block_iq6_k);
}
-#ifdef __AVX2__
-namespace {
-inline int hsum_i32_8(const __m256i a) {
- const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
- const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
- const __m128i sum64 = _mm_add_epi32(hi64, sum128);
- const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
- return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
-}
-inline float hmax_f32_8(__m256 x) {
- __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
- max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
- max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
- return _mm_cvtss_f32(max4);
-}
-}
-#endif
void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
assert(k % QK_K == 0);
@@ -3680,3 +3793,126 @@ void vec_dot_iq4_xs_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t
GGML_UNUSED(bx);
GGML_UNUSED(by);
}
+
+//
+// ========================================= iq2_bn_r4
+//
+void quantize_row_iq2_bn_r4_ref(const float * x, block_iq2_bn * y, int64_t k) {
+ quantize_iq2_bn_r4(x, (void *)y, 4, k/4, nullptr);
+}
+
+void quantize_row_iq2_bn_r4(const float * x, void * y, int64_t k) {
+ quantize_iq2_bn_r4(x, y, 4, k/4, nullptr);
+}
+
+namespace {
+void repack_iq2_bn(int nrows, int n_per_row, const char * x, char * y) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_IQ1BN == 0);
+ int nblock = n_per_row/QK_IQ1BN;
+ auto row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row);
+ const uint8_t * x4[4];
+ for (int row = 0; row < nrows; row += 4) {
+ float * dr4 = (float *)(y + 4*row*row_size);
+ for (int k = 0; k < 4; ++k) {
+ const float * dptr = (const float *)(x + (row + k)*row_size);
+ dr4[k] = *dptr;
+ x4[k] = (const uint8_t *)(dptr + 1);
+ }
+ uint8_t * y4 = (uint8_t *)(dr4 + 4);
+ //std::memset(y4, 0, n_per_row);
+ for (int ib = 0; ib < nblock; ++ib) {
+ // 0...3 from rows 0...3 go to 1st 2 bits of 0...15
+ // 16..19 from rows 0...3 go to 1st 2 bits of 16...31
+ // 32..35 from rows 0...3 go to 1st 2 bits of 32...47
+ // 48..51 from rows 0...3 go to 1st 2 bits of 48...63
+ // 4...7 from rows 0...3 go to 2nd 2 bits of 0...15
+ // 20..23 from rows 0...3 go to 2nd 2 bits of 16...31
+ // 36..39 from rows 0...3 go to 2nd 2 bits of 32...47
+ // 52..55 from rows 0...3 go to 2nd 2 bits of 48...63
+ // 8..11 from rows 0...3 go to 3rd 2 bits of 0...15
+ // 24..27 from rows 0...3 go to 3rd 2 bits of 16...31
+ // 40..43 from rows 0...3 go to 3rd 2 bits of 32...47
+ // 56..59 from rows 0...3 go to 3rd 2 bits of 48...63
+ // 12..15 from rows 0...3 go to 4th 2 bits of 0...15
+ // 28..31 from rows 0...3 go to 4th 2 bits of 16...31
+ // 44..47 from rows 0...3 go to 4th 2 bits of 32...47
+ // 60..63 from rows 0...3 go to 4th 2 bits of 48...63
+ for (int k = 0; k < 4; ++k) {
+ for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) {
+ y4[64*ib + 4*k + i + 16*l] = (((x4[k][16*ib + i + 0] >> 2*l) & 3) << 0) |
+ (((x4[k][16*ib + i + 4] >> 2*l) & 3) << 2) |
+ (((x4[k][16*ib + i + 8] >> 2*l) & 3) << 4) |
+ (((x4[k][16*ib + i + 12] >> 2*l) & 3) << 6);
+ //y4[64*ib + 4*k + i + 0] |= (x4[k][16*ib + i] >> 0) & 3;
+ //y4[64*ib + 4*k + i + 16] |= (x4[k][16*ib + i] >> 2) & 3;
+ //y4[64*ib + 4*k + i + 32] |= (x4[k][16*ib + i] >> 4) & 3;
+ //y4[64*ib + 4*k + i + 48] |= (x4[k][16*ib + i] >> 6) & 3;
+ //y4[64*ib + 4*k + i + 0] |= ((x4[k][16*ib + i + 4] >> 0) & 3) << 2;
+ //y4[64*ib + 4*k + i + 16] |= ((x4[k][16*ib + i + 4] >> 2) & 3) << 2;
+ //y4[64*ib + 4*k + i + 32] |= ((x4[k][16*ib + i + 4] >> 4) & 3) << 2;
+ //y4[64*ib + 4*k + i + 48] |= ((x4[k][16*ib + i + 4] >> 6) & 3) << 2;
+ //y4[64*ib + 4*k + i + 0] |= ((x4[k][16*ib + i + 8] >> 0) & 3) << 4;
+ //y4[64*ib + 4*k + i + 16] |= ((x4[k][16*ib + i + 8] >> 2) & 3) << 4;
+ //y4[64*ib + 4*k + i + 32] |= ((x4[k][16*ib + i + 8] >> 4) & 3) << 4;
+ //y4[64*ib + 4*k + i + 48] |= ((x4[k][16*ib + i + 8] >> 6) & 3) << 4;
+ //y4[64*ib + 4*k + i + 0] |= ((x4[k][16*ib + i + 12] >> 0) & 3) << 6;
+ //y4[64*ib + 4*k + i + 16] |= ((x4[k][16*ib + i + 12] >> 2) & 3) << 6;
+ //y4[64*ib + 4*k + i + 32] |= ((x4[k][16*ib + i + 12] >> 4) & 3) << 6;
+ //y4[64*ib + 4*k + i + 48] |= ((x4[k][16*ib + i + 12] >> 6) & 3) << 6;
+ }
+ }
+ }
+ }
+}
+}
+
+size_t quantize_iq2_bn_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_IQ1BN == 0);
+ char * qcur = (char *)dst;
+ auto row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row);
+ std::vector<char> qtmp(4*row_size);
+ for (int row = 0; row < nrows; row += 4) {
+ quantize_iq2_bn(src, (void *)qtmp.data(), 4, n_per_row, imatrix);
+ repack_iq2_bn(4, n_per_row, qtmp.data(), qcur);
+ qcur += 4*row_size;
+ src += 4*n_per_row;
+ }
+ return nrows*row_size;
+}
+
+void dequantize_row_iq2_bn_r4(const block_iq2_bn * x, float * y, int64_t k) {
+ static_assert(QK_IQ1BN == 64);
+ auto n_per_row = k/4;
+ float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row};
+ const float * d4 = (const float *)x;
+ const uint8_t * qx = (const uint8_t *)(d4 + 4);
+ int nblock = n_per_row/QK_IQ1BN;
+ for (int ib = 0; ib < nblock; ++ib) {
+ for (int k = 0; k < 4; ++k) {
+ for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) {
+ uint8_t q = qx[4*k + i + 16*l];
+ y4[k][64*ib + 16*l + i + 0] = d4[k] * (((q >> 0) & 3) - 1);
+ y4[k][64*ib + 16*l + i + 4] = d4[k] * (((q >> 2) & 3) - 1);
+ y4[k][64*ib + 16*l + i + 8] = d4[k] * (((q >> 4) & 3) - 1);
+ y4[k][64*ib + 16*l + i + 12] = d4[k] * (((q >> 6) & 3) - 1);
+ }
+ }
+ qx += 64;
+ }
+}
+
+void vec_dot_iq2_bn_r4_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN_R4, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
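
A hypothetical usage sketch for the new entry point quantize_iq2_bn_r4 above; the include paths and the helper name are assumptions, and the buffer sizing mirrors what the function itself uses (ggml_row_size of GGML_TYPE_IQ2_BN per row):

    #include <cstdint>
    #include <vector>
    #include "ggml.h"              // assumed include path for ggml_row_size / GGML_TYPE_IQ2_BN
    #include "iqk/iqk_quantize.h"  // assumed include path for quantize_iq2_bn_r4

    // Pack nrows rows of float weights (nrows % 4 == 0, n_per_row % 64 == 0)
    // straight into the interleaved iq2_bn_r4 layout.
    static std::vector<char> pack_iq2_bn_r4(const float * src, int64_t nrows, int64_t n_per_row) {
        const size_t row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row); // same per-row size as iq2_bn
        std::vector<char> dst(nrows * row_size);
        quantize_iq2_bn_r4(src, dst.data(), nrows, n_per_row, /*imatrix=*/nullptr);
        return dst;
    }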
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index ad2294c5..fb436e57 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -99,6 +99,22 @@ size_t quantize_iq4_xs_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT
void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq4_xs_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_iq2_bn_ref (const float * GGML_RESTRICT x, block_iq2_bn * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq2_bn (const block_iq2_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq2_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void vec_dot_iq2_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void quantize_row_iq2_bn_r4_ref (const float * GGML_RESTRICT x, block_iq2_bn * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_bn_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq2_bn_r4(const block_iq2_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq2_bn_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void vec_dot_iq2_bn_r4_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
#ifdef __cplusplus
}
#endif
diff --git a/include/llama.h b/include/llama.h
index 77c988a5..243fddaa 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -186,6 +186,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_NL_X4 = 225, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 = 230, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_0_R4 = 235, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ2_BN_R4 = 237, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
diff --git a/src/llama.cpp b/src/llama.cpp
index e2abc235..ad76a7b8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3848,6 +3848,7 @@ struct llama_model_loader {
case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break;
case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break;
case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break;
+ case GGML_TYPE_IQ2_BN_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN_R4;break;
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ4_NL_X4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL_X4;break;
case GGML_TYPE_IQ4_XS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS_R4;break;
@@ -4576,6 +4577,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ6_K: return "IQ6_K - 6.6 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_BN: return "IQ1_BN - 1.625 bpw Bitnet";
case LLAMA_FTYPE_MOSTLY_IQ2_BN: return "IQ2_BN - 2.00 bpw Bitnet";
+ case LLAMA_FTYPE_MOSTLY_IQ2_BN_R4:return "IQ2_BN_R4 - 2.00 bpw Bitnet";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
@@ -15771,7 +15773,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
new_type = GGML_TYPE_IQ3_S;
}
- else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN) {
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN_R4) {
new_type = GGML_TYPE_IQ4_NL;
}
else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 ||
@@ -16061,7 +16063,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
++qs.n_k_quantized;
}
}
- if (new_type == GGML_TYPE_IQ1_BN || new_type == GGML_TYPE_IQ2_BN) {
+ if (new_type == GGML_TYPE_IQ1_BN || new_type == GGML_TYPE_IQ2_BN || new_type == GGML_TYPE_IQ2_BN_R4) {
int nx = tensor->ne[0];
if (nx % QK_IQ1BN != 0) {
convert_incompatible_tensor = true;
@@ -16190,6 +16192,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break;
case LLAMA_FTYPE_MOSTLY_IQ1_BN: default_type = GGML_TYPE_IQ1_BN; break;
case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break;
+ case LLAMA_FTYPE_MOSTLY_IQ2_BN_R4:default_type = GGML_TYPE_IQ2_BN_R4;break;
case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
case LLAMA_FTYPE_MOSTLY_IQ4_NL_X4:default_type = GGML_TYPE_IQ4_NL_X4;break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS_R4:default_type = GGML_TYPE_IQ4_XS_R4;break;
@@ -16574,6 +16577,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q8_0;
else chunk_size_multiplier = 4;
}
+ else if (new_type == GGML_TYPE_IQ2_BN_R4) {
+ if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_BN;
+ else chunk_size_multiplier = 4;
+ }
LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
fflush(stdout);