diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-16 14:25:12 +0300 |
---|---|---|
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:51 +0300 |
commit | f20b28558bdd20454ce891d36db5f37de819025a (patch) | |
tree | e9b54a2086cd0e4fb35d9dede9822bdbf3d6bc10 | |
parent | 58756ef03ff3f19a98187395d12af3f19f121f90 (diff) |
bitnet: python + llama
-rwxr-xr-x | convert-hf-to-gguf.py | 29 | ||||
-rw-r--r-- | gguf-py/gguf/constants.py | 24 | ||||
-rw-r--r-- | gguf-py/gguf/tensor_mapping.py | 8 | ||||
-rw-r--r-- | iqk-quantize.cpp | 437 | ||||
-rw-r--r-- | llama.cpp | 281 |
5 files changed, 779 insertions, 0 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a6751cc8..0c08b800 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1400,6 +1400,35 @@ class LlamaModel(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("BitnetForCausalLM") +class BitnetModel(Model): + model_arch = gguf.MODEL_ARCH.BITNET + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + + def weight_quant(self, weight): + dtype = weight.dtype + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) / s + return result.type(dtype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # transform weight into 1/0/-1 (in fp32) + if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", + "down_proj.weight", "up_proj.weight", "gate_proj.weight", + "o_proj.weight")): + data_torch = self.weight_quant(data_torch) + + return [(self.map_tensor_name(name), data_torch)] + + @Model.register("GrokForCausalLM") class GrokModel(Model): model_arch = gguf.MODEL_ARCH.GROK diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index fb20cfab..4cc3e35f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -149,6 +149,7 @@ class MODEL_ARCH(IntEnum): OLMO = auto() ARCTIC = auto() DEEPSEEK2 = auto() + BITNET = auto() class MODEL_TENSOR(IntEnum): @@ -200,6 +201,8 @@ class MODEL_TENSOR(IntEnum): ATTN_KV_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() + FFN_SUB_NORM = auto() + ATTN_SUB_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -237,6 +240,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.OLMO: "olmo", MODEL_ARCH.ARCTIC: "arctic", MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.BITNET: "bitnet", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -288,6 +292,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", + MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", + MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -808,6 +814,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.BITNET: [ + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_SUB_NORM, + MODEL_TENSOR.FFN_SUB_NORM, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 81b4992a..350035bd 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -413,6 +413,14 @@ class TensorNameMap: MODEL_TENSOR.ATTN_KV_A_NORM: ( "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2 ), + + MODEL_TENSOR.ATTN_SUB_NORM: ( + "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet + ), + + MODEL_TENSOR.FFN_SUB_NORM: ( + "model.layers.{bid}.mlp.ffn_layernorm", # bitnet + ), } # architecture-specific block mappings diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp new file mode 100644 index 00000000..8b071a3e --- /dev/null +++ b/iqk-quantize.cpp @@ -0,0 +1,437 @@ +#include "ggml-quants.h" +#include "ggml-impl.h" +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#include <vector> +#include <utility> +#include <cstdint> +#include <cmath> +#include <array> +#include <algorithm> +#include <cstring> +#include <mutex> + +namespace { + +inline int nearest_int(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +struct IQ1BNData { + IQ1BNData(); + std::vector<std::pair<int16_t, bool>> map; + std::vector<uint16_t> rmap; +}; + +const IQ1BNData& get_iq1bn_data() { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + static IQ1BNData iq1bn; + return iq1bn; +} + +IQ1BNData::IQ1BNData() { + map.resize(1 << 16, {int16_t(-1), false}); + uint64_t aux64; + uint8_t * aux8 = (uint8_t *)&aux64; + std::vector<uint64_t> values; + values.reserve(6561); + rmap.reserve(6561); + for (int i = 0; i < (1 << 16); ++i) { + bool is_good = true; + for (int j = 0; j < 8; ++j) { + aux8[j] = (i >> 2*j) & 3; + if (aux8[j] == 3u) { is_good = false; break; } + } + if (!is_good) continue; + auto orig = aux64; + for (int j = 0; j < 8; ++j) aux8[j] = 2 - aux8[j]; + int k = 0; + for (; k < int(values.size()); ++k) { + if (values[k] == aux64) break; + } + if (k < int(values.size())) { + map[i] = {k, true}; + } else { + map[i].first = values.size(); + values.push_back(orig); + rmap.push_back(i); + } + } + printf("==================== %s: initialized %d grid points\n", __func__, int(rmap.size())); +} + +struct IQ1BNQuantizer { + typedef union { + float f; + uint32_t i; + } scale_t; + constexpr static int block_size = QK_IQ1BN; + int8_t L[QK_IQ1BN]; + void quantize_one_row(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix); +}; + +void IQ1BNQuantizer::quantize_one_row(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) { + + (void)imatrix; + + constexpr int Nk = block_size/8; + + const int nblock = n_per_row/QK_IQ1BN; + + const auto& iq1bn = get_iq1bn_data(); + + float max_in_row = 0; + for (int j = 0; j < n_per_row; ++j) { + float ax = fabsf(src[j]); + max_in_row = std::max(max_in_row, ax); + } + + max_in_row *= 1.03125f; // i.e., round to nearest in our fp8 representation + scale_t s; + uint8_t u = 0; + if (max_in_row > 1.9074e-06f && max_in_row < 0.12109f) { + s.f = max_in_row; + u = ((((s.i >> 23) + 132) & 0xf) << 4) | ((s.i >> 19) & 0xf); + s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + } else { + // outside the allowed range. Small values we can habdle via quants set to zero, so we only warn about too large values + if (max_in_row >= 0.12109f) { + u = 255; + fprintf(stderr, "%s: found scale %g, which is outside the range of out fp8 representation\n", __func__, max_in_row); + } else{ + u = 0; + } + } + + for (int ib = 0; ib < nblock; ++ib) { + std::memset(&y[ib], 0, sizeof(block_iq1_bn)); + auto xb = src + QK_IQ1BN*ib; + for (int j = 0; j < QK_IQ1BN; ++j) { + L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2; + } + auto ql = y[ib].ql; + auto qh = y[ib].qh; + uint16_t extra = 0; + for (int k = 0; k < Nk; ++k) { + auto Lk = L + 8*k; + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (Lk[j] << 2*j); + auto& val = iq1bn.map[u]; + GGML_ASSERT(val.first >= 0); + ql[k] = val.first & 255; + qh[k/2] |= (val.first >> 8) << 4*(k%2); + if (val.second) extra |= (1 << k); + } + + y[ib].extra = u | (extra << 8); + + } +} +} + +void iq1bn_init_impl(void) { + get_iq1bn_data(); +} + +size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + IQ1BNQuantizer iq1bn; + int nblock = n_per_row/QK_IQ1BN; + block_iq1_bn * y = (block_iq1_bn *)dst; + for (int row = 0; row < nrows; ++row) { + iq1bn.quantize_one_row(src + row*n_per_row, y, n_per_row, imatrix); + y += nblock; + } + return sizeof(block_iq1_bn)*nblock*nrows; +} + +void quantize_row_iq1_bn_reference(const float * x, block_iq1_bn * y, int64_t k) { + quantize_iq1_bn(x, y, 1, k, nullptr); +} + +void quantize_row_iq1_bn(const float * x, void * y, int64_t k) { + quantize_iq1_bn(x, y, 1, k, nullptr); +} + +void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) { + assert(k%QK_IQ1BN == 0); + int nblock = k / QK_IQ1BN; + + IQ1BNQuantizer::scale_t s; + + for (int i = 0; i < nblock; ++i) { + uint16_t u = x[i].extra & 0xff; + s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + float d = s.f; + uint8_t extra = x[i].extra >> 8; + auto qh = x[i].qh; + auto ql = x[i].ql; + for (int k = 0; k < QK_IQ1BN/8; ++k) { + uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00); + uint16_t val = iq1bn_grid_u16[idx]; + float dls = extra & (1 << k) ? -d : d; + for (int j = 0; j < 8; ++j) y[j] = dls * (((val >> 2*j) & 3) - 1); + y += 8; + } + } +} + +#if __AVX__ || __AVX2__ || __AVX512F__ +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} +#endif + +void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(nrc); + + static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64"); + + const block_iq1_bn * x = (const block_iq1_bn *)vx; + const block_q8_0 * y = (const block_q8_0 *)vy; + int nblock = n / QK_IQ1BN; + + float sumf = 0; + IQ1BNQuantizer::scale_t scale; + +#if defined __AVX2__ + + const auto m1_8 = _mm256_set1_epi8(1); + const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000); + const auto shuff2 = _mm256_add_epi8(shuff1, m1_8); + const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404); + const auto mask1 = _mm256_set1_epi64x(0x8040201008040201); +#if !(defined __AVX512VNNI__ && defined __AVX512VL__) + const auto m1_16 = _mm256_set1_epi16(1); +#endif + + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + + // All scales are the same in BitNet! + uint16_t u = x[0].extra & 0xff; + scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + + for (int i = 0; i < nblock; ++i) { + // We would uncomment this if we wanted to use this implementation for a model that has per block scales + //uint16_t u = x[i].extra & 0xff; + //scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + auto signs = _mm256_set1_epi8(x[i].extra >> 8); + // signs for groups of 8 ordered as 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ... + // To use these to sign the q8 values we need + // 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 amd the same for 4...7 + signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, mask1), mask1), m1_8); + auto q8_1 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[2*i+0].qs), _mm256_shuffle_epi8(signs, shuff3)); + auto q8_2 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[2*i+1].qs), _mm256_shuffle_epi8(signs, shuff4)); + + auto ql = x[i].ql; + auto qh = x[i].qh; + auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)], + iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]); + auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)], + iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]); + + auto v1_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1); + auto v1_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1); + auto v2_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1); + auto v2_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1); + + auto dot1 = _mm256_sub_epi8(_mm256_sign_epi8(q8_1, v1_m), _mm256_sign_epi8(q8_1, v1_p)); + auto dot2 = _mm256_sub_epi8(_mm256_sign_epi8(q8_2, v2_m), _mm256_sign_epi8(q8_2, v2_p)); + +#if defined __AVX512VNNI__ && defined __AVX512VL__ + dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1); + dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot2); +#else + dot1 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot1)); + dot2 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot2)); +#endif + + // We would uncomment this if we wanted to use this implementation for a model that has per block scales + //acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1); + //acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2); + // All scales are the same for BitNet! + // This is slower + //uint32_t aux32 = y[2*i+0].d | (y[2*i+1].d << 16); + //auto d8 = _mm256_cvtph_ps(_mm_set1_epi32(aux32)); + //acc1 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x00), _mm256_cvtepi32_ps(dot1), acc1); + //acc2 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x55), _mm256_cvtepi32_ps(dot2), acc2); + acc1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1); + acc2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2); + + } + + //sumf = hsum_float_8(_mm256_add_ps(acc1, acc2)); + sumf = scale.f * hsum_float_8(_mm256_add_ps(acc1, acc2)); + +#else + + for (int i = 0; i < nblock; ++i) { + uint16_t u = x[i].extra & 0xff; + scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + uint8_t extra = x[i].extra >> 8; + auto qh = x[i].qh; + auto ql = x[i].ql; + auto q8 = y[2*i+0].qs; + int16_t sumi1 = 0; + for (int k = 0; k < 4; ++k) { + uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00); + uint16_t val = iq1bn_grid_u16[idx]; + int16_t sl = 0; + for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1); + sumi1 += extra & (1 << k) ? -sl : sl; + q8 += 8; + } + q8 = y[2*i+1].qs; + int16_t sumi2 = 0; + for (int k = 4; k < 8; ++k) { + uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00); + uint16_t val = iq1bn_grid_u16[idx]; + int16_t sl = 0; + for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1); + sumi2 += extra & (1 << k) ? -sl : sl; + q8 += 8; + } + sumf += scale.f * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2); + } + +#endif + + *s = sumf; + +} + +void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(nrc); + + static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64"); + + const block_iq1_bn * x = (const block_iq1_bn *)vx; + const block_q8_K64 * y = (const block_q8_K64 *)vy; + int nblock = n / QK_IQ1BN; + + float sumf = 0; + IQ1BNQuantizer::scale_t scale; + +#if defined __AVX2__ + + const auto m1_8 = _mm256_set1_epi8(1); + const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000); + const auto shuff2 = _mm256_add_epi8(shuff1, m1_8); + const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404); + const auto mask1 = _mm256_set1_epi64x(0x8040201008040201); +#if !(defined __AVX512VNNI__ && defined __AVX512VL__) + const auto m1_16 = _mm256_set1_epi16(1); +#endif + + __m256 acc = _mm256_setzero_ps(); + + // All scales are the same in BitNet! + uint16_t u = x[0].extra & 0xff; + scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + + for (int i = 0; i < nblock; ++i) { + // We would uncomment this if we wanted to use this implementation for a model that has per block scales + //uint16_t u = x[i].extra & 0xff; + //scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + auto signs = _mm256_set1_epi8(x[i].extra >> 8); + // signs for groups of 8 ordered as 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ... + // To use these to sign the q8 values we need + // 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 amd the same for 4...7 + signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, mask1), mask1), m1_8); + auto q8_1 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[i].qs+0), _mm256_shuffle_epi8(signs, shuff3)); + auto q8_2 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[i].qs+1), _mm256_shuffle_epi8(signs, shuff4)); + + auto ql = x[i].ql; + auto qh = x[i].qh; + auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)], + iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]); + auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)], + iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]); + + auto v1_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1); + auto v1_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1); + auto v2_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1); + auto v2_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1); + + auto dot1 = _mm256_sub_epi8(_mm256_sign_epi8(q8_1, v1_m), _mm256_sign_epi8(q8_1, v1_p)); + auto dot2 = _mm256_sub_epi8(_mm256_sign_epi8(q8_2, v2_m), _mm256_sign_epi8(q8_2, v2_p)); + +#if defined __AVX512VNNI__ && defined __AVX512VL__ + dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1); + dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot2); +#else + dot1 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot1)); + dot2 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot2)); +#endif + + // We would uncomment this if we wanted to use this implementation for a model that has per block scales + //acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1); + //acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2); + // All scales are the same for BitNet! + // This is slower + //uint32_t aux32 = y[2*i+0].d | (y[2*i+1].d << 16); + //auto d8 = _mm256_cvtph_ps(_mm_set1_epi32(aux32)); + //acc1 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x00), _mm256_cvtepi32_ps(dot1), acc1); + //acc2 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x55), _mm256_cvtepi32_ps(dot2), acc2); + acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d), _mm256_cvtepi32_ps(_mm256_add_epi32(dot1, dot2)), acc); + + } + + sumf = scale.f * hsum_float_8(acc); + +#else + + for (int i = 0; i < nblock; ++i) { + uint16_t u = x[i].extra & 0xff; + scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + uint8_t extra = x[i].extra >> 8; + auto qh = x[i].qh; + auto ql = x[i].ql; + auto q8 = y[2*i+0].qs; + int16_t sumi1 = 0; + for (int k = 0; k < 4; ++k) { + uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00); + uint16_t val = iq1bn_grid_u16[idx]; + int16_t sl = 0; + for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1); + sumi1 += extra & (1 << k) ? -sl : sl; + q8 += 8; + } + q8 = y[2*i+1].qs; + int16_t sumi2 = 0; + for (int k = 4; k < 8; ++k) { + uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00); + uint16_t val = iq1bn_grid_u16[idx]; + int16_t sl = 0; + for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1); + sumi2 += extra & (1 << k) ? -sl : sl; + q8 += 8; + } + sumf += scale.f * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2); + } + +#endif + + *s = sumf; + +} @@ -225,6 +225,7 @@ enum llm_arch { LLM_ARCH_OLMO, LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_BITNET, LLM_ARCH_UNKNOWN, }; @@ -263,6 +264,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_OLMO, "olmo" }, { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -500,6 +502,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_KV_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + LLM_TENSOR_FFN_SUB_NORM, }; static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = { @@ -1114,6 +1118,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA }, }, { + LLM_ARCH_BITNET, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, + { LLM_ARCH_UNKNOWN, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, @@ -2118,6 +2140,8 @@ struct llama_layer { struct ggml_tensor * attn_out_norm_b; struct ggml_tensor * attn_q_a_norm; struct ggml_tensor * attn_kv_a_norm; + struct ggml_tensor * attn_sub_norm; + struct ggml_tensor * ffn_sub_norm; // attention struct ggml_tensor * wq; @@ -4710,6 +4734,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BITNET: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -6655,6 +6688,40 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_BITNET: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + } + + const uint32_t n_ff = hparams.n_ff; + model.layers.resize(n_layer); + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -11709,6 +11776,215 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bitnet() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + // B1.K + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + // B1.V + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); + + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + struct ggml_tensor * q_cur = Qcur; + struct ggml_tensor * kq_mask = KQ_mask; + float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm; + struct ggml_cgraph * graph = gf; + struct ggml_tensor * wo = model.layers[il].wo; + struct ggml_tensor * cur_attn; + struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + cb(q, "q", il); + + struct ggml_tensor * k = + ggml_view_3d(ctx0, kv_self.k_l[il], + n_embd_head_k, n_kv, n_head_kv, + ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), + 0); + cb(k, "k", il); + + if (cparams.flash_attn) { + + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v), + 0); + cb(v, "v", il); + + cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); + + cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + GGML_ASSERT(kv_self.size == n_ctx); + + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv_self.v_l[il])*n_ctx, + ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur_attn, "kqv_merged_cont", il); + } + + cur_attn = llm_build_norm(ctx0, cur_attn, hparams, + attn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur_attn, "attn_sub_norm", il); + + ggml_build_forward_expand(graph, cur_attn); + + cur = ggml_mul_mat(ctx0, wo, cur_attn); + + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward forward + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); + + cb(tmp, "ffn_up", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur); + + cb(cur, "ffn_gate", il); + + + cur = ggml_silu(ctx0, cur); + cb(cur, "ffn_silu", il); + + cur = ggml_mul(ctx0, cur, tmp); + cb(cur, "ffn_gate_par", il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_sub_norm", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); + cb(cur, "ffn_down", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.tok_embd, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + return gf; + } + }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) { @@ -11932,6 +12208,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_deepseek2(); } break; + case LLM_ARCH_BITNET: + { + result = llm.build_bitnet(); + } break; default: GGML_ASSERT(false); } @@ -16760,6 +17040,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_STABLELM: + case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: |