summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/quantize-stats/quantize-stats.cpp8
-rw-r--r--examples/quantize/quantize.cpp1
-rw-r--r--ggml-common.h8
-rw-r--r--ggml-quants.c1
-rw-r--r--ggml-quants.h5
-rw-r--r--ggml.c22
-rw-r--r--ggml.h4
-rw-r--r--iqk-quantize.cpp343
-rw-r--r--llama.cpp7
-rw-r--r--llama.h1
10 files changed, 216 insertions, 184 deletions
diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp
index 746df844..4eb8f953 100644
--- a/examples/quantize-stats/quantize-stats.cpp
+++ b/examples/quantize-stats/quantize-stats.cpp
@@ -341,6 +341,10 @@ int main(int argc, char ** argv) {
if (!layer_included(params, kv_tensor.first)) {
continue;
}
+ if (kv_tensor.second->ne[0] == 1 || kv_tensor.second->ne[1] == 1) {
+ // we never quantize those
+ continue;
+ }
if (params.verbose) {
printf("%s: type %s, size %" PRId64 "\n", kv_tensor.first.c_str(), ggml_type_name(kv_tensor.second->type), ggml_nelements(kv_tensor.second));
}
@@ -386,6 +390,10 @@ int main(int argc, char ** argv) {
if (!layer_included(params, kv_tensor.first)) {
continue;
}
+ if (kv_tensor.second->ne[0] == 1 || kv_tensor.second->ne[1] == 1) {
+ // we never quantize those
+ continue;
+ }
if (params.verbose) {
printf(" %s ...\n", kv_tensor.first.c_str());
}
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 06927890..a5ffb2b2 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -27,6 +27,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.75 bpw quantization (Bitnet)", },
+ { "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
diff --git a/ggml-common.h b/ggml-common.h
index 148b41d5..c7f865e8 100644
--- a/ggml-common.h
+++ b/ggml-common.h
@@ -380,6 +380,14 @@ typedef struct {
uint8_t qh[QK_IQ1BN/16];
} block_iq1_bn;
static_assert(sizeof(block_iq1_bn) == sizeof(uint16_t) + QK_IQ1BN/8 + QK_IQ1BN/16, "wrong iq1_bn block size/padding");
+//
+// Bitnet - implemented as 2.0 bpw
+//
+#define QK_IQ2BN 64
+typedef struct {
+ uint8_t qs[QK_IQ2BN/4];
+} block_iq2_bn;
+static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding");
// Used by IQ1_M quants
typedef union {
diff --git a/ggml-quants.c b/ggml-quants.c
index 31817b1c..f1ce1345 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -15056,6 +15056,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
// nothing to validate
break;
default:
diff --git a/ggml-quants.h b/ggml-quants.h
index a66fe6d6..cc677207 100644
--- a/ggml-quants.h
+++ b/ggml-quants.h
@@ -33,6 +33,7 @@ void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs
void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
void quantize_row_iq1_bn_reference (const float * GGML_RESTRICT x, block_iq1_bn * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_bn_reference (const float * GGML_RESTRICT x, block_iq2_bn * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -55,6 +56,7 @@ void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq1_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
// Dequantization
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -81,6 +83,7 @@ void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_
void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq1_bn (const block_iq1_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq2_bn (const block_iq2_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -106,6 +109,7 @@ void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq1_bn_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq1_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
@@ -118,6 +122,7 @@ size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq1_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq2_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
diff --git a/ggml.c b/ggml.c
index e1439e2f..4d089c43 100644
--- a/ggml.c
+++ b/ggml.c
@@ -874,6 +874,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K64,
.nrows = 1,
},
+ [GGML_TYPE_IQ2_BN] = {
+ .type_name = "iq2_bn",
+ .blck_size = QK_IQ1BN,
+ .type_size = sizeof(block_iq2_bn),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_bn,
+ .from_float = quantize_row_iq2_bn,
+ .from_float_reference = (ggml_from_float_t)quantize_row_iq2_bn_reference,
+ .vec_dot = ggml_vec_dot_iq2_bn_q8_K64,
+ .vec_dot_type = GGML_TYPE_Q8_K64,
+ .nrows = 1,
+ },
[GGML_TYPE_IQ4_NL] = {
.type_name = "iq4_nl",
.blck_size = QK4_NL,
@@ -3206,6 +3218,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break;
case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break;
+ case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break;
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
@@ -9500,6 +9513,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
@@ -9900,6 +9914,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
@@ -10030,6 +10045,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
@@ -13003,6 +13019,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
@@ -13197,6 +13214,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
@@ -13473,6 +13491,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
@@ -14087,6 +14106,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
@@ -21082,6 +21102,7 @@ void ggml_quantize_init(enum ggml_type type) {
case GGML_TYPE_IQ1_M: iq2xs_init_impl(type); break;
case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
case GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break;
+ case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ1_BN: iq1bn_init_impl(); break;
default: // nothing
break;
@@ -21152,6 +21173,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_F16:
diff --git a/ggml.h b/ggml.h
index 070af417..3d6e4283 100644
--- a/ggml.h
+++ b/ggml.h
@@ -384,7 +384,8 @@ extern "C" {
GGML_TYPE_IQ1_M = 29,
GGML_TYPE_BF16 = 30,
GGML_TYPE_IQ1_BN = 31,
- GGML_TYPE_Q8_K64 = 32,
+ GGML_TYPE_IQ2_BN = 32,
+ GGML_TYPE_Q8_K64 = 33,
GGML_TYPE_COUNT,
};
@@ -427,6 +428,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_BN = 25, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_BN = 26, // except 1d tensors
};
// available tensor operations:
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp
index 1cc598bf..287a7dd3 100644
--- a/iqk-quantize.cpp
+++ b/iqk-quantize.cpp
@@ -87,24 +87,46 @@ struct IQ1BNQuantizer {
} scale_t;
constexpr static int block_size = QK_IQ1BN;
int8_t L[QK_IQ1BN];
- void quantize_one_row(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix);
+ void quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix);
+ void quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix);
+ static inline float row_max(int n_per_row, const float * src) {
+ float max_in_row = 0;
+ for (int j = 0; j < n_per_row; ++j) {
+ float ax = fabsf(src[j]);
+ max_in_row = std::max(max_in_row, ax);
+ }
+ return max_in_row;
+ }
+ static uint16_t quantize_one_block_1bn(const IQ1BNData& iq1l, const float * xb, int8_t * L, uint8_t * ql, uint8_t * qh);
};
-void IQ1BNQuantizer::quantize_one_row(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) {
+uint16_t IQ1BNQuantizer::quantize_one_block_1bn(const IQ1BNData& iq1bn, const float * xb, int8_t * L, uint8_t * ql, uint8_t * qh) {
+ for (int j = 0; j < QK_IQ1BN; ++j) {
+ L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2;
+ }
+ uint16_t extra = 0;
+ for (int k = 0; k < QK_IQ1BN/8; ++k) {
+ auto Lk = L + 8*k;
+ uint16_t u = 0;
+ for (int j = 0; j < 8; ++j) u |= (Lk[j] << 2*j);
+ auto& val = iq1bn.map[u];
+ GGML_ASSERT(val.first >= 0);
+ ql[k] = val.first & 255;
+ qh[k/2] |= (val.first >> 8) << 4*(k%2);
+ if (val.second) extra |= (1 << k);
+ }
+ return extra;
+}
- (void)imatrix;
+void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) {
- constexpr int Nk = block_size/8;
+ (void)imatrix;
const int nblock = n_per_row/QK_IQ1BN;
const auto& iq1bn = get_iq1bn_data();
- float max_in_row = 0;
- for (int j = 0; j < n_per_row; ++j) {
- float ax = fabsf(src[j]);
- max_in_row = std::max(max_in_row, ax);
- }
+ auto max_in_row = row_max(n_per_row, src);
max_in_row *= 1.03125f; // i.e., round to nearest in our fp8 representation
scale_t s;
@@ -126,27 +148,45 @@ void IQ1BNQuantizer::quantize_one_row(const float * src, block_iq1_bn * y, int n
for (int ib = 0; ib < nblock; ++ib) {
std::memset(&y[ib], 0, sizeof(block_iq1_bn));
auto xb = src + QK_IQ1BN*ib;
+ auto extra = quantize_one_block_1bn(iq1bn, xb, L, y[ib].ql, y[ib].qh);
+ y[ib].extra = u | (extra << 8);
+ }
+}
+
+void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix) {
+
+ (void)imatrix;
+
+ const int nblock = n_per_row/QK_IQ1BN;
+
+ const auto& iq1bn = get_iq1bn_data();
+
+ auto max_in_row = row_max(n_per_row, src);
+
+ ggml_half * d = (ggml_half *)y;
+ *d = GGML_FP32_TO_FP16(max_in_row);
+
+ auto ql = (uint8_t *)(d + 2);
+ auto qh = ql + QK_IQ1BN/8;
+ std::memset(ql, 0, QK_IQ1BN/8);
+ std::memset(qh, 0, QK_IQ1BN/16);
+ auto xb = src;
+ auto extra = quantize_one_block_1bn(iq1bn, xb, L, ql, qh);
+ *(uint16_t *)(d + 1) = extra;
+
+ constexpr int Nj = QK_IQ1BN/4;
+
+ for (int ib = 1; ib < nblock; ++ib) {
+ xb = src + QK_IQ1BN*ib;
for (int j = 0; j < QK_IQ1BN; ++j) {
L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2;
}
- auto ql = y[ib].ql;
- auto qh = y[ib].qh;
- uint16_t extra = 0;
- for (int k = 0; k < Nk; ++k) {
- auto Lk = L + 8*k;
- uint16_t u = 0;
- for (int j = 0; j < 8; ++j) u |= (Lk[j] << 2*j);
- auto& val = iq1bn.map[u];
- GGML_ASSERT(val.first >= 0);
- ql[k] = val.first & 255;
- qh[k/2] |= (val.first >> 8) << 4*(k%2);
- if (val.second) extra |= (1 << k);
+ for (int j = 0; j < Nj; ++j) {
+ y[ib].qs[j] = L[j] | (L[j + Nj] << 2) | (L[j + 2*Nj] << 4) | (L[j + 3*Nj] << 6);
}
-
- y[ib].extra = u | (extra << 8);
-
}
}
+
}
void iq1bn_init_impl(void) {
@@ -158,7 +198,7 @@ size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_p
int nblock = n_per_row/QK_IQ1BN;
block_iq1_bn * y = (block_iq1_bn *)dst;
for (int row = 0; row < nrows; ++row) {
- iq1bn.quantize_one_row(src + row*n_per_row, y, n_per_row, imatrix);
+ iq1bn.quantize_one_row_1bn(src + row*n_per_row, y, n_per_row, imatrix);
y += nblock;
}
return sizeof(block_iq1_bn)*nblock*nrows;
@@ -195,16 +235,54 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
}
}
-#if __AVX__ || __AVX2__ || __AVX512F__
-// horizontally add 8 floats
-static inline float hsum_float_8(const __m256 x) {
- __m128 res = _mm256_extractf128_ps(x, 1);
- res = _mm_add_ps(res, _mm256_castps256_ps128(x));
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
- return _mm_cvtss_f32(res);
+size_t quantize_iq2_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ IQ1BNQuantizer iq1bn;
+ int nblock = n_per_row/QK_IQ1BN;
+ block_iq2_bn * y = (block_iq2_bn *)dst;
+ for (int row = 0; row < nrows; ++row) {
+ iq1bn.quantize_one_row_2bn(src + row*n_per_row, y, n_per_row, imatrix);
+ y += nblock;
+ }
+ return sizeof(block_iq2_bn)*nblock*nrows;
+}
+
+void quantize_row_iq2_bn_reference(const float * x, block_iq2_bn * y, int64_t k) {
+ quantize_iq2_bn(x, y, 1, k, nullptr);
+}
+
+void quantize_row_iq2_bn(const float * x, void * y, int64_t k) {
+ quantize_iq2_bn(x, y, 1, k, nullptr);
+}
+
+void dequantize_row_iq2_bn(const block_iq2_bn * x, float * y, int64_t k) {
+ assert(k%QK_IQ1BN == 0);
+ int nblock = k / QK_IQ1BN;
+
+ float d = GGML_FP16_TO_FP32(*(const ggml_half *)x);
+ auto * extra_ptr = (const uint16_t *)x;
+ auto extra = extra_ptr[1];
+ auto ql = (const uint8_t *)(extra_ptr + 2);
+ auto qh = ql + QK_IQ1BN/8;
+ for (int l = 0; l < QK_IQ1BN/8; ++l) {
+ uint16_t idx = ql[l] | ((qh[l/2] << (8 - 4*(l%2))) & 0x0f00);
+ uint16_t val = iq1bn_grid_u16[idx];
+ float dls = extra & (1 << l) ? -d : d;
+ for (int j = 0; j < 8; ++j) y[j] = dls * (((val >> 2*j) & 3) - 1);
+ y += 8;
+ }
+ auto m = -d;
+ auto d1 = d, d2 = d*0.25f, d3 = d2*0.25f, d4 = d3*0.25f;
+ constexpr int Nj = QK_IQ1BN/4;
+ for (int i = 1; i < nblock; ++i) {
+ for (int j = 0; j < Nj; ++j) {
+ y[j+ 0] = d1*(x[i].qs[j] & 0x03) + m;
+ y[j+1*Nj] = d2*(x[i].qs[j] & 0x0c) + m;
+ y[j+2*Nj] = d3*(x[i].qs[j] & 0x30) + m;
+ y[j+3*Nj] = d4*(x[i].qs[j] & 0xc0) + m;
+ }
+ y += QK_IQ1BN;
+ }
}
-#endif
void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
@@ -222,79 +300,6 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz
float sumf = 0;
IQ1BNQuantizer::scale_t scale;
-#if defined __AVX2__
-
- const auto m1_8 = _mm256_set1_epi8(1);
- const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
- const auto shuff2 = _mm256_add_epi8(shuff1, m1_8);
- const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
- const auto mask1 = _mm256_set1_epi64x(0x8040201008040201);
-#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
- const auto m1_16 = _mm256_set1_epi16(1);
-#endif
-
- __m256 acc1 = _mm256_setzero_ps();
- __m256 acc2 = _mm256_setzero_ps();
-
- // All scales are the same in BitNet!
- uint16_t u = x[0].extra & 0xff;
- scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
-
- for (int i = 0; i < nblock; ++i) {
- // We would uncomment this if we wanted to use this implementation for a model that has per block scales
- //uint16_t u = x[i].extra & 0xff;
- //scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
- auto signs = _mm256_set1_epi8(x[i].extra >> 8);
- // signs for groups of 8 ordered as 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ...
- // To use these to sign the q8 values we need
- // 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 amd the same for 4...7
- signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, mask1), mask1), m1_8);
- auto q8_1 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[2*i+0].qs), _mm256_shuffle_epi8(signs, shuff3));
- auto q8_2 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[2*i+1].qs), _mm256_shuffle_epi8(signs, shuff4));
-
- auto ql = x[i].ql;
- auto qh = x[i].qh;
- auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
- iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
- auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
- iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);
-
- auto v1_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1);
- auto v1_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1);
- auto v2_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1);
- auto v2_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1);
-
- auto dot1 = _mm256_sub_epi8(_mm256_sign_epi8(q8_1, v1_m), _mm256_sign_epi8(q8_1, v1_p));
- auto dot2 = _mm256_sub_epi8(_mm256_sign_epi8(q8_2, v2_m), _mm256_sign_epi8(q8_2, v2_p));
-
-#if defined __AVX512VNNI__ && defined __AVX512VL__
- dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1);
- dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot2);
-#else
- dot1 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot1));
- dot2 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot2));
-#endif
-
- // We would uncomment this if we wanted to use this implementation for a model that has per block scales
- //acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1);
- //acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2);
- // All scales are the same for BitNet!
- // This is slower
- //uint32_t aux32 = y[2*i+0].d | (y[2*i+1].d << 16);
- //auto d8 = _mm256_cvtph_ps(_mm_set1_epi32(aux32));
- //acc1 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x00), _mm256_cvtepi32_ps(dot1), acc1);
- //acc2 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x55), _mm256_cvtepi32_ps(dot2), acc2);
- acc1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1);
- acc2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2);
-
- }
-
- //sumf = hsum_float_8(_mm256_add_ps(acc1, acc2));
- sumf = scale.f * hsum_float_8(_mm256_add_ps(acc1, acc2));
-
-#else
-
for (int i = 0; i < nblock; ++i) {
uint16_t u = x[i].extra & 0xff;
scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
@@ -324,8 +329,6 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz
sumf += scale.f * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2);
}
-#endif
-
*s = sumf;
}
@@ -346,76 +349,6 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
float sumf = 0;
IQ1BNQuantizer::scale_t scale;
-#if defined __AVX2__
-
- const auto m1_8 = _mm256_set1_epi8(1);
- const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
- const auto shuff2 = _mm256_add_epi8(shuff1, m1_8);
- const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
- const auto mask1 = _mm256_set1_epi64x(0x8040201008040201);
-#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
- const auto m1_16 = _mm256_set1_epi16(1);
-#endif
-
- __m256 acc = _mm256_setzero_ps();
-
- // All scales are the same in BitNet!
- uint16_t u = x[0].extra & 0xff;
- scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
-
- for (int i = 0; i < nblock; ++i) {
- // We would uncomment this if we wanted to use this implementation for a model that has per block scales
- //uint16_t u = x[i].extra & 0xff;
- //scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
- auto signs = _mm256_set1_epi8(x[i].extra >> 8);
- // signs for groups of 8 ordered as 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ...
- // To use these to sign the q8 values we need
- // 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 amd the same for 4...7
- signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, mask1), mask1), m1_8);
- auto q8_1 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[i].qs+0), _mm256_shuffle_epi8(signs, shuff3));
- auto q8_2 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[i].qs+1), _mm256_shuffle_epi8(signs, shuff4));
-
- auto ql = x[i].ql;
- auto qh = x[i].qh;
- auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
- iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
- auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
- iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);
-
- auto v1_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1);
- auto v1_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1);
- auto v2_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1);
- auto v2_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1);
-
- auto dot1 = _mm256_sub_epi8(_mm256_sign_epi8(q8_1, v1_m), _mm256_sign_epi8(q8_1, v1_p));
- auto dot2 = _mm256_sub_epi8(_mm256_sign_epi8(q8_2, v2_m), _mm256_sign_epi8(q8_2, v2_p));
-
-#if defined __AVX512VNNI__ && defined __AVX512VL__
- dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1);
- dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot2);
-#else
- dot1 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot1));
- dot2 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot2));
-#endif
-
- // We would uncomment this if we wanted to use this implementation for a model that has per block scales
- //acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1);
- //acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2);
- // All scales are the same for BitNet!
- // This is slower
- //uint32_t aux32 = y[2*i+0].d | (y[2*i+1].d << 16);
- //auto d8 = _mm256_cvtph_ps(_mm_set1_epi32(aux32));
- //acc1 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x00), _mm256_cvtepi32_ps(dot1), acc1);
- //acc2 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x55), _mm256_cvtepi32_ps(dot2), acc2);
- acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d), _mm256_cvtepi32_ps(_mm256_add_epi32(dot1, dot2)), acc);
-
- }
-
- sumf = scale.f * hsum_float_8(acc);
-
-#else
-
uint16_t u = x[0].extra & 0xff;
scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
for (int i = 0; i < nblock; ++i) {
@@ -443,9 +376,57 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
sumf += scale.f * (y[i].d) * sumi;
}
-#endif
-
*s = sumf;
+}
+
+void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(nrc);
+
+ static_assert(QK_IQ1BN == 64, "This dot product implementation for iq2_bn requires a block size of 64");
+
+ constexpr int Nj = QK_IQ1BN/4;
+
+ const block_iq2_bn * x = (const block_iq2_bn *)vx;
+ const block_q8_K64 * y = (const block_q8_K64 *)vy;
+ int nblock = n / QK_IQ1BN;
+
+ float sumf = 0;
+
+ float d = GGML_FP16_TO_FP32(*(const ggml_half *)x);
+ auto * extra_ptr = (const uint16_t *)x;
+ auto extra = extra_ptr[1];
+ auto ql = (const uint8_t *)(extra_ptr + 2);
+ auto qh = ql + QK_IQ1BN/8;
+ auto q8 = y[0].qs;
+ int sumi = 0;
+ for (int k = 0; k < QK_IQ1BN/8; ++k) {
+ uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
+ uint16_t val = iq1bn_grid_u16[idx];
+ int s = 0;
+ for (int j = 0; j < 8; ++j) s += q8[j] * (((val >> 2*j) & 3) - 1);
+ sumi += extra & (1 << k) ? -s : s;
+ q8 += 8;
+ }
+ sumf += y[0].d * sumi;
+
+ for (int i = 1; i < nblock; ++i) {
+ q8 = y[i].qs;
+ int s0 = 0, s1 = 0, s2 = 0, s3 = 0, s4 = 0;
+ for (int j = 0; j < Nj; ++j) {
+ s1 += q8[j+ 0] * (x[i].qs[j] & 0x03);
+ s2 += q8[j+1*Nj] * (x[i].qs[j] & 0x0c);
+ s3 += q8[j+2*Nj] * (x[i].qs[j] & 0x30);
+ s4 += q8[j+3*Nj] * (x[i].qs[j] & 0xc0);
+ s0 += q8[j] + q8[j+1*Nj] + q8[j+2*Nj] + q8[j+3*Nj];
+ }
+ sumf += y[i].d * (s1 + 0.25f*s2 + 0.0625*s3 + 0.015625*s4 - s0);
+ }
+
+ *s = sumf * d;
}
diff --git a/llama.cpp b/llama.cpp
index 1c5027b2..08cb6c35 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3503,6 +3503,7 @@ struct llama_model_loader {
case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break;
case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break;
case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break;
+ case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break;
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
@@ -4121,6 +4122,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_M :return "IQ1_M - 1.75 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_BN :return "IQ1_BN - 1.75 bpw Bitnet";
+ case LLAMA_FTYPE_MOSTLY_IQ2_BN :return "IQ2_BN - 2.00 bpw Bitnet";
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
@@ -15473,7 +15475,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
new_type = GGML_TYPE_IQ3_S;
}
- else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN) {
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN) {
new_type = GGML_TYPE_IQ4_NL;
}
}
@@ -15673,7 +15675,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
++qs.n_k_quantized;
}
}
- if (new_type == GGML_TYPE_IQ1_BN) {
+ if (new_type == GGML_TYPE_IQ1_BN || new_type == GGML_TYPE_IQ2_BN) {
int nx = tensor->ne[0];
if (nx % QK_IQ1BN != 0) {
convert_incompatible_tensor = true;
@@ -15791,6 +15793,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break;
case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break;
case LLAMA_FTYPE_MOSTLY_IQ1_BN: default_type = GGML_TYPE_IQ1_BN; break;
+ case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break;
case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
diff --git a/llama.h b/llama.h
index 4a03edc1..e630cc66 100644
--- a/llama.h
+++ b/llama.h
@@ -158,6 +158,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_BN = 33,
+ LLAMA_FTYPE_MOSTLY_IQ2_BN = 34,
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};