summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xconvert-hf-to-gguf.py29
-rw-r--r--gguf-py/gguf/constants.py24
-rw-r--r--gguf-py/gguf/tensor_mapping.py8
-rw-r--r--iqk-quantize.cpp437
-rw-r--r--llama.cpp281
5 files changed, 779 insertions, 0 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index a6751cc8..0c08b800 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1400,6 +1400,35 @@ class LlamaModel(Model):
raise ValueError(f"Unprocessed experts: {experts}")
+@Model.register("BitnetForCausalLM")
+class BitnetModel(Model):
+ model_arch = gguf.MODEL_ARCH.BITNET
+
+ def set_vocab(self):
+ self._set_vocab_sentencepiece()
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+ self.gguf_writer.add_rope_scaling_factor(1.0)
+
+ def weight_quant(self, weight):
+ dtype = weight.dtype
+ weight = weight.float()
+ s = 1 / weight.abs().mean().clamp(min=1e-5)
+ result = (weight * s).round().clamp(-1, 1) / s
+ return result.type(dtype)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # transform weight into 1/0/-1 (in fp32)
+ if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
+ "down_proj.weight", "up_proj.weight", "gate_proj.weight",
+ "o_proj.weight")):
+ data_torch = self.weight_quant(data_torch)
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+
@Model.register("GrokForCausalLM")
class GrokModel(Model):
model_arch = gguf.MODEL_ARCH.GROK
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index fb20cfab..4cc3e35f 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -149,6 +149,7 @@ class MODEL_ARCH(IntEnum):
OLMO = auto()
ARCTIC = auto()
DEEPSEEK2 = auto()
+ BITNET = auto()
class MODEL_TENSOR(IntEnum):
@@ -200,6 +201,8 @@ class MODEL_TENSOR(IntEnum):
ATTN_KV_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
+ FFN_SUB_NORM = auto()
+ ATTN_SUB_NORM = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -237,6 +240,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
+ MODEL_ARCH.BITNET: "bitnet",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -288,6 +292,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
+ MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
+ MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -808,6 +814,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
+ MODEL_ARCH.BITNET: [
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_SUB_NORM,
+ MODEL_TENSOR.FFN_SUB_NORM,
+ ],
# TODO
}
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 81b4992a..350035bd 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -413,6 +413,14 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_KV_A_NORM: (
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
),
+
+ MODEL_TENSOR.ATTN_SUB_NORM: (
+ "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet
+ ),
+
+ MODEL_TENSOR.FFN_SUB_NORM: (
+ "model.layers.{bid}.mlp.ffn_layernorm", # bitnet
+ ),
}
# architecture-specific block mappings
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp
new file mode 100644
index 00000000..8b071a3e
--- /dev/null
+++ b/iqk-quantize.cpp
@@ -0,0 +1,437 @@
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#include <vector>
+#include <utility>
+#include <cstdint>
+#include <cmath>
+#include <array>
+#include <algorithm>
+#include <cstring>
+#include <mutex>
+
+namespace {
+
+inline int nearest_int(float fval) {
+ assert(fval <= 4194303.f);
+ float val = fval + 12582912.f;
+ int i; memcpy(&i, &val, sizeof(int));
+ return (i & 0x007fffff) - 0x00400000;
+}
+
+struct IQ1BNData {
+ IQ1BNData();
+ std::vector<std::pair<int16_t, bool>> map;
+ std::vector<uint16_t> rmap;
+};
+
+const IQ1BNData& get_iq1bn_data() {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ static IQ1BNData iq1bn;
+ return iq1bn;
+}
+
+IQ1BNData::IQ1BNData() {
+ map.resize(1 << 16, {int16_t(-1), false});
+ uint64_t aux64;
+ uint8_t * aux8 = (uint8_t *)&aux64;
+ std::vector<uint64_t> values;
+ values.reserve(6561);
+ rmap.reserve(6561);
+ for (int i = 0; i < (1 << 16); ++i) {
+ bool is_good = true;
+ for (int j = 0; j < 8; ++j) {
+ aux8[j] = (i >> 2*j) & 3;
+ if (aux8[j] == 3u) { is_good = false; break; }
+ }
+ if (!is_good) continue;
+ auto orig = aux64;
+ for (int j = 0; j < 8; ++j) aux8[j] = 2 - aux8[j];
+ int k = 0;
+ for (; k < int(values.size()); ++k) {
+ if (values[k] == aux64) break;
+ }
+ if (k < int(values.size())) {
+ map[i] = {k, true};
+ } else {
+ map[i].first = values.size();
+ values.push_back(orig);
+ rmap.push_back(i);
+ }
+ }
+ printf("==================== %s: initialized %d grid points\n", __func__, int(rmap.size()));
+}
+
+struct IQ1BNQuantizer {
+ typedef union {
+ float f;
+ uint32_t i;
+ } scale_t;
+ constexpr static int block_size = QK_IQ1BN;
+ int8_t L[QK_IQ1BN];
+ void quantize_one_row(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix);
+};
+
+void IQ1BNQuantizer::quantize_one_row(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) {
+
+ (void)imatrix;
+
+ constexpr int Nk = block_size/8;
+
+ const int nblock = n_per_row/QK_IQ1BN;
+
+ const auto& iq1bn = get_iq1bn_data();
+
+ float max_in_row = 0;
+ for (int j = 0; j < n_per_row; ++j) {
+ float ax = fabsf(src[j]);
+ max_in_row = std::max(max_in_row, ax);
+ }
+
+ max_in_row *= 1.03125f; // i.e., round to nearest in our fp8 representation
+ scale_t s;
+ uint8_t u = 0;
+ if (max_in_row > 1.9074e-06f && max_in_row < 0.12109f) {
+ s.f = max_in_row;
+ u = ((((s.i >> 23) + 132) & 0xf) << 4) | ((s.i >> 19) & 0xf);
+ s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ } else {
+ // outside the allowed range. Small values we can habdle via quants set to zero, so we only warn about too large values
+ if (max_in_row >= 0.12109f) {
+ u = 255;
+ fprintf(stderr, "%s: found scale %g, which is outside the range of out fp8 representation\n", __func__, max_in_row);
+ } else{
+ u = 0;
+ }
+ }
+
+ for (int ib = 0; ib < nblock; ++ib) {
+ std::memset(&y[ib], 0, sizeof(block_iq1_bn));
+ auto xb = src + QK_IQ1BN*ib;
+ for (int j = 0; j < QK_IQ1BN; ++j) {
+ L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2;
+ }
+ auto ql = y[ib].ql;
+ auto qh = y[ib].qh;
+ uint16_t extra = 0;
+ for (int k = 0; k < Nk; ++k) {
+ auto Lk = L + 8*k;
+ uint16_t u = 0;
+ for (int j = 0; j < 8; ++j) u |= (Lk[j] << 2*j);
+ auto& val = iq1bn.map[u];
+ GGML_ASSERT(val.first >= 0);
+ ql[k] = val.first & 255;
+ qh[k/2] |= (val.first >> 8) << 4*(k%2);
+ if (val.second) extra |= (1 << k);
+ }
+
+ y[ib].extra = u | (extra << 8);
+
+ }
+}
+}
+
+void iq1bn_init_impl(void) {
+ get_iq1bn_data();
+}
+
+size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ IQ1BNQuantizer iq1bn;
+ int nblock = n_per_row/QK_IQ1BN;
+ block_iq1_bn * y = (block_iq1_bn *)dst;
+ for (int row = 0; row < nrows; ++row) {
+ iq1bn.quantize_one_row(src + row*n_per_row, y, n_per_row, imatrix);
+ y += nblock;
+ }
+ return sizeof(block_iq1_bn)*nblock*nrows;
+}
+
+void quantize_row_iq1_bn_reference(const float * x, block_iq1_bn * y, int64_t k) {
+ quantize_iq1_bn(x, y, 1, k, nullptr);
+}
+
+void quantize_row_iq1_bn(const float * x, void * y, int64_t k) {
+ quantize_iq1_bn(x, y, 1, k, nullptr);
+}
+
+void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
+ assert(k%QK_IQ1BN == 0);
+ int nblock = k / QK_IQ1BN;
+
+ IQ1BNQuantizer::scale_t s;
+
+ for (int i = 0; i < nblock; ++i) {
+ uint16_t u = x[i].extra & 0xff;
+ s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ float d = s.f;
+ uint8_t extra = x[i].extra >> 8;
+ auto qh = x[i].qh;
+ auto ql = x[i].ql;
+ for (int k = 0; k < QK_IQ1BN/8; ++k) {
+ uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
+ uint16_t val = iq1bn_grid_u16[idx];
+ float dls = extra & (1 << k) ? -d : d;
+ for (int j = 0; j < 8; ++j) y[j] = dls * (((val >> 2*j) & 3) - 1);
+ y += 8;
+ }
+ }
+}
+
+#if __AVX__ || __AVX2__ || __AVX512F__
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+ __m128 res = _mm256_extractf128_ps(x, 1);
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
+ return _mm_cvtss_f32(res);
+}
+#endif
+
+void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(nrc);
+
+ static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64");
+
+ const block_iq1_bn * x = (const block_iq1_bn *)vx;
+ const block_q8_0 * y = (const block_q8_0 *)vy;
+ int nblock = n / QK_IQ1BN;
+
+ float sumf = 0;
+ IQ1BNQuantizer::scale_t scale;
+
+#if defined __AVX2__
+
+ const auto m1_8 = _mm256_set1_epi8(1);
+ const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
+ const auto shuff2 = _mm256_add_epi8(shuff1, m1_8);
+ const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
+ const auto mask1 = _mm256_set1_epi64x(0x8040201008040201);
+#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
+ const auto m1_16 = _mm256_set1_epi16(1);
+#endif
+
+ __m256 acc1 = _mm256_setzero_ps();
+ __m256 acc2 = _mm256_setzero_ps();
+
+ // All scales are the same in BitNet!
+ uint16_t u = x[0].extra & 0xff;
+ scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+
+ for (int i = 0; i < nblock; ++i) {
+ // We would uncomment this if we wanted to use this implementation for a model that has per block scales
+ //uint16_t u = x[i].extra & 0xff;
+ //scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ auto signs = _mm256_set1_epi8(x[i].extra >> 8);
+ // signs for groups of 8 ordered as 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ...
+ // To use these to sign the q8 values we need
+ // 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 amd the same for 4...7
+ signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, mask1), mask1), m1_8);
+ auto q8_1 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[2*i+0].qs), _mm256_shuffle_epi8(signs, shuff3));
+ auto q8_2 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[2*i+1].qs), _mm256_shuffle_epi8(signs, shuff4));
+
+ auto ql = x[i].ql;
+ auto qh = x[i].qh;
+ auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
+ iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
+ auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
+ iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);
+
+ auto v1_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1);
+ auto v1_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1);
+ auto v2_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1);
+ auto v2_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1);
+
+ auto dot1 = _mm256_sub_epi8(_mm256_sign_epi8(q8_1, v1_m), _mm256_sign_epi8(q8_1, v1_p));
+ auto dot2 = _mm256_sub_epi8(_mm256_sign_epi8(q8_2, v2_m), _mm256_sign_epi8(q8_2, v2_p));
+
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+ dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1);
+ dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot2);
+#else
+ dot1 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot1));
+ dot2 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot2));
+#endif
+
+ // We would uncomment this if we wanted to use this implementation for a model that has per block scales
+ //acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1);
+ //acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2);
+ // All scales are the same for BitNet!
+ // This is slower
+ //uint32_t aux32 = y[2*i+0].d | (y[2*i+1].d << 16);
+ //auto d8 = _mm256_cvtph_ps(_mm_set1_epi32(aux32));
+ //acc1 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x00), _mm256_cvtepi32_ps(dot1), acc1);
+ //acc2 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x55), _mm256_cvtepi32_ps(dot2), acc2);
+ acc1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1);
+ acc2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2);
+
+ }
+
+ //sumf = hsum_float_8(_mm256_add_ps(acc1, acc2));
+ sumf = scale.f * hsum_float_8(_mm256_add_ps(acc1, acc2));
+
+#else
+
+ for (int i = 0; i < nblock; ++i) {
+ uint16_t u = x[i].extra & 0xff;
+ scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ uint8_t extra = x[i].extra >> 8;
+ auto qh = x[i].qh;
+ auto ql = x[i].ql;
+ auto q8 = y[2*i+0].qs;
+ int16_t sumi1 = 0;
+ for (int k = 0; k < 4; ++k) {
+ uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
+ uint16_t val = iq1bn_grid_u16[idx];
+ int16_t sl = 0;
+ for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
+ sumi1 += extra & (1 << k) ? -sl : sl;
+ q8 += 8;
+ }
+ q8 = y[2*i+1].qs;
+ int16_t sumi2 = 0;
+ for (int k = 4; k < 8; ++k) {
+ uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
+ uint16_t val = iq1bn_grid_u16[idx];
+ int16_t sl = 0;
+ for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
+ sumi2 += extra & (1 << k) ? -sl : sl;
+ q8 += 8;
+ }
+ sumf += scale.f * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2);
+ }
+
+#endif
+
+ *s = sumf;
+
+}
+
+void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(nrc);
+
+ static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64");
+
+ const block_iq1_bn * x = (const block_iq1_bn *)vx;
+ const block_q8_K64 * y = (const block_q8_K64 *)vy;
+ int nblock = n / QK_IQ1BN;
+
+ float sumf = 0;
+ IQ1BNQuantizer::scale_t scale;
+
+#if defined __AVX2__
+
+ const auto m1_8 = _mm256_set1_epi8(1);
+ const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
+ const auto shuff2 = _mm256_add_epi8(shuff1, m1_8);
+ const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
+ const auto mask1 = _mm256_set1_epi64x(0x8040201008040201);
+#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
+ const auto m1_16 = _mm256_set1_epi16(1);
+#endif
+
+ __m256 acc = _mm256_setzero_ps();
+
+ // All scales are the same in BitNet!
+ uint16_t u = x[0].extra & 0xff;
+ scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+
+ for (int i = 0; i < nblock; ++i) {
+ // We would uncomment this if we wanted to use this implementation for a model that has per block scales
+ //uint16_t u = x[i].extra & 0xff;
+ //scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ auto signs = _mm256_set1_epi8(x[i].extra >> 8);
+ // signs for groups of 8 ordered as 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ...
+ // To use these to sign the q8 values we need
+ // 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 amd the same for 4...7
+ signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, mask1), mask1), m1_8);
+ auto q8_1 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[i].qs+0), _mm256_shuffle_epi8(signs, shuff3));
+ auto q8_2 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[i].qs+1), _mm256_shuffle_epi8(signs, shuff4));
+
+ auto ql = x[i].ql;
+ auto qh = x[i].qh;
+ auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
+ iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
+ auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
+ iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);
+
+ auto v1_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1);
+ auto v1_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1);
+ auto v2_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1);
+ auto v2_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1);
+
+ auto dot1 = _mm256_sub_epi8(_mm256_sign_epi8(q8_1, v1_m), _mm256_sign_epi8(q8_1, v1_p));
+ auto dot2 = _mm256_sub_epi8(_mm256_sign_epi8(q8_2, v2_m), _mm256_sign_epi8(q8_2, v2_p));
+
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+ dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1);
+ dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot2);
+#else
+ dot1 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot1));
+ dot2 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot2));
+#endif
+
+ // We would uncomment this if we wanted to use this implementation for a model that has per block scales
+ //acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1);
+ //acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2);
+ // All scales are the same for BitNet!
+ // This is slower
+ //uint32_t aux32 = y[2*i+0].d | (y[2*i+1].d << 16);
+ //auto d8 = _mm256_cvtph_ps(_mm_set1_epi32(aux32));
+ //acc1 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x00), _mm256_cvtepi32_ps(dot1), acc1);
+ //acc2 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x55), _mm256_cvtepi32_ps(dot2), acc2);
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d), _mm256_cvtepi32_ps(_mm256_add_epi32(dot1, dot2)), acc);
+
+ }
+
+ sumf = scale.f * hsum_float_8(acc);
+
+#else
+
+ for (int i = 0; i < nblock; ++i) {
+ uint16_t u = x[i].extra & 0xff;
+ scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ uint8_t extra = x[i].extra >> 8;
+ auto qh = x[i].qh;
+ auto ql = x[i].ql;
+ auto q8 = y[2*i+0].qs;
+ int16_t sumi1 = 0;
+ for (int k = 0; k < 4; ++k) {
+ uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
+ uint16_t val = iq1bn_grid_u16[idx];
+ int16_t sl = 0;
+ for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
+ sumi1 += extra & (1 << k) ? -sl : sl;
+ q8 += 8;
+ }
+ q8 = y[2*i+1].qs;
+ int16_t sumi2 = 0;
+ for (int k = 4; k < 8; ++k) {
+ uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
+ uint16_t val = iq1bn_grid_u16[idx];
+ int16_t sl = 0;
+ for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
+ sumi2 += extra & (1 << k) ? -sl : sl;
+ q8 += 8;
+ }
+ sumf += scale.f * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2);
+ }
+
+#endif
+
+ *s = sumf;
+
+}
diff --git a/llama.cpp b/llama.cpp
index a05a52b4..619ffa4e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -225,6 +225,7 @@ enum llm_arch {
LLM_ARCH_OLMO,
LLM_ARCH_ARCTIC,
LLM_ARCH_DEEPSEEK2,
+ LLM_ARCH_BITNET,
LLM_ARCH_UNKNOWN,
};
@@ -263,6 +264,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_OLMO, "olmo" },
{ LLM_ARCH_ARCTIC, "arctic" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
+ { LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -500,6 +502,8 @@ enum llm_tensor {
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
+ LLM_TENSOR_ATTN_SUB_NORM,
+ LLM_TENSOR_FFN_SUB_NORM,
};
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -1114,6 +1118,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
+ LLM_ARCH_BITNET,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
+ },
+ },
+ {
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -2118,6 +2140,8 @@ struct llama_layer {
struct ggml_tensor * attn_out_norm_b;
struct ggml_tensor * attn_q_a_norm;
struct ggml_tensor * attn_kv_a_norm;
+ struct ggml_tensor * attn_sub_norm;
+ struct ggml_tensor * ffn_sub_norm;
// attention
struct ggml_tensor * wq;
@@ -4710,6 +4734,15 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_BITNET:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 26: model.type = e_model::MODEL_3B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
@@ -6655,6 +6688,40 @@ static bool llm_load_tensors(
}
}
} break;
+ case LLM_ARCH_BITNET:
+ {
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+ }
+
+ const uint32_t n_ff = hparams.n_ff;
+ model.layers.resize(n_layer);
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+ layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
+
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+ layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
+
+ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -11709,6 +11776,215 @@ struct llm_build_context {
return gf;
}
+ struct ggml_cgraph * build_bitnet() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ // B1.K
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ // B1.V
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+
+ const int64_t n_ctx = cparams.n_ctx;
+ const int64_t n_head = hparams.n_head;
+ const int64_t n_head_kv = hparams.n_head_kv;
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+ const int64_t n_embd_head_v = hparams.n_embd_head_v;
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+ struct ggml_tensor * q_cur = Qcur;
+ struct ggml_tensor * kq_mask = KQ_mask;
+ float kq_scale = 1.0f/sqrtf(float(n_embd_head));
+ struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm;
+ struct ggml_cgraph * graph = gf;
+ struct ggml_tensor * wo = model.layers[il].wo;
+ struct ggml_tensor * cur_attn;
+ struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+ cb(q, "q", il);
+
+ struct ggml_tensor * k =
+ ggml_view_3d(ctx0, kv_self.k_l[il],
+ n_embd_head_k, n_kv, n_head_kv,
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+ 0);
+ cb(k, "k", il);
+
+ if (cparams.flash_attn) {
+
+ // split cached v into n_head heads (not transposed)
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx0, kv_self.v_l[il],
+ n_embd_head_v, n_kv, n_head_kv,
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
+ 0);
+ cb(v, "v", il);
+
+ cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+
+ cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+ } else {
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);
+
+ kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ GGML_ASSERT(kv_self.size == n_ctx);
+
+ // split cached v into n_head heads
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx0, kv_self.v_l[il],
+ n_kv, n_embd_head_v, n_head_kv,
+ ggml_element_size(kv_self.v_l[il])*n_ctx,
+ ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+ 0);
+ cb(v, "v", il);
+
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+ cb(kqv, "kqv", il);
+
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
+
+ cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+ cb(cur_attn, "kqv_merged_cont", il);
+ }
+
+ cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
+ attn_sub_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur_attn, "attn_sub_norm", il);
+
+ ggml_build_forward_expand(graph, cur_attn);
+
+ cur = ggml_mul_mat(ctx0, wo, cur_attn);
+
+ cb(cur, "kqv_out", il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward forward
+ if (model.layers[il].ffn_gate_inp == nullptr) {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
+
+ cb(tmp, "ffn_up", il);
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
+
+ cb(cur, "ffn_gate", il);
+
+
+ cur = ggml_silu(ctx0, cur);
+ cb(cur, "ffn_silu", il);
+
+ cur = ggml_mul(ctx0, cur, tmp);
+ cb(cur, "ffn_gate_par", il);
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].ffn_sub_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_sub_norm", il);
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
+ cb(cur, "ffn_down", il);
+ }
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+ return gf;
+ }
+
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -11932,6 +12208,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_deepseek2();
} break;
+ case LLM_ARCH_BITNET:
+ {
+ result = llm.build_bitnet();
+ } break;
default:
GGML_ASSERT(false);
}
@@ -16760,6 +17040,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_BERT:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_STABLELM:
+ case LLM_ARCH_BITNET:
case LLM_ARCH_QWEN:
case LLM_ARCH_QWEN2:
case LLM_ARCH_QWEN2MOE: