summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-04-26 08:13:25 +0200
committerGitHub <noreply@github.com>2025-04-26 08:13:25 +0200
commit715fc552ad2ea5fad38e7ff856bf84fdb71b692e (patch)
treea9999d3d2169222d9be38256de95c3d3859f3b31
parent770892086c15471a397a6a1a196986de906cdc91 (diff)
Add support for Cohere2 (#341)
* Add support for Cohere2 * Fixe IQ4_NL on AVX2 * Command-A needs fp32 precision for K*Q --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml.c2
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp49
-rw-r--r--src/llama.cpp203
3 files changed, 247 insertions, 7 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index ad9393cc..88013f74 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1289,7 +1289,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
.vec_dot = ggml_vec_dot_iq4_nl_q8_0,
#if GGML_USE_IQK_MULMAT
-#if defined __AVX2__
+#if defined HAVE_FANCY_SIMD
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 45d804a4..e7ab2e5b 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -1750,6 +1750,15 @@ __m256i inline load_iq4nl_values_256() {
return MM256_SET_M128I(val128, val128);
}
+__m128i inline load_iq4k_values_128() {
+ return _mm_loadu_si128((const __m128i *)iq4k_values);
+}
+
+__m256i inline load_iq4k_values_256() {
+ auto val128 = load_iq4k_values_128();
+ return MM256_SET_M128I(val128, val128);
+}
+
#ifdef HAVE_FANCY_SIMD
//====================================== Zen4 ==================================================
@@ -8519,7 +8528,11 @@ struct Q4_0_1_Dequantizer {
struct IQ4_NL_Dequantizer {
Dequantizer4bit b4;
+#ifdef HAVE_FANCY_SIMD
const __m256i values = load_iq4nl_values_256();
+#else
+ const __m256i values = load_iq4k_values_256();
+#endif
inline __m256i dequant(const block_iq4_nl * x) const {
return _mm256_shuffle_epi8(values, b4.dequant(x->qs));
}
@@ -8630,11 +8643,19 @@ struct Q4_0_1_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0_1<8>
using Sum4T = Sum4TypeQ82;
inline static int block_size() { return QK4_0; }
};
+#ifdef HAVE_FANCY_SIMD
struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0_1<128>, IQ4_NL_Dequantizer> {
IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ82;
inline static int block_size() { return QK4_NL; }
};
+#else
+struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL_Dequantizer> {
+ IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ80;
+ inline static int block_size() { return QK4_NL; }
+};
+#endif
struct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {
Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ80;
@@ -9155,9 +9176,29 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
m.funcs[6] = mul_mat_qX_1_q8_2_T<Dequantizer, 7>;
m.funcs[7] = mul_mat_qX_1_q8_2_T<Dequantizer, 8>;
}
+ else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) {
+#ifdef HAVE_FANCY_SIMD
+ m.funcs[0] = mul_mat_qX_1_q8_2_T<Dequantizer, 1>;
+ m.funcs[1] = mul_mat_qX_1_q8_2_T<Dequantizer, 2>;
+ m.funcs[2] = mul_mat_qX_1_q8_2_T<Dequantizer, 3>;
+ m.funcs[3] = mul_mat_qX_1_q8_2_T<Dequantizer, 4>;
+ m.funcs[4] = mul_mat_qX_1_q8_2_T<Dequantizer, 5>;
+ m.funcs[5] = mul_mat_qX_1_q8_2_T<Dequantizer, 6>;
+ m.funcs[6] = mul_mat_qX_1_q8_2_T<Dequantizer, 7>;
+ m.funcs[7] = mul_mat_qX_1_q8_2_T<Dequantizer, 8>;
+#else
+ m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
+ m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;
+ m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
+ m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;
+ m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;
+ m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;
+ m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
+ m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
+#endif
+ }
else if constexpr (std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || std::is_same_v<Dequantizer, Q4_0_1_Unpacker> ||
- std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker> ||
- std::is_same_v<Dequantizer, Q6_0_1_Unpacker>) {
+ std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, Q6_0_1_Unpacker>) {
m.funcs[0] = mul_mat_qX_1_q8_2_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_1_q8_2_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_1_q8_2_T<Dequantizer, 3>;
@@ -9476,7 +9517,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ4_NL:
assert (ne00 % QK4_NL == 0);
MulMat::set_functions<IQ4_NL_Unpacker>(mm);
+#ifdef HAVE_FANCY_SIMD
expected_typeB = GGML_TYPE_Q8_2_X4;
+#else
+ expected_typeB = GGML_TYPE_Q8_0_X4;
+#endif
break;
case GGML_TYPE_IQ4_NL_R4:
assert (ne00 % QK4_NL == 0);
diff --git a/src/llama.cpp b/src/llama.cpp
index c870b09e..26ddeb2e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -229,6 +229,7 @@ enum llm_arch {
LLM_ARCH_JAIS,
LLM_ARCH_GRANITE = 46,
LLM_ARCH_GRANITE_MOE,
+ LLM_ARCH_COHERE2,
LLM_ARCH_UNKNOWN,
};
@@ -279,6 +280,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
+ { LLM_ARCH_COHERE2, "cohere2" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -1456,7 +1458,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
-
+ {
+ LLM_ARCH_COHERE2,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_UNKNOWN,
{
@@ -2539,6 +2555,7 @@ struct llama_hparams {
if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true;
if (this->n_swa != other.n_swa) return true;
+ if (this->n_swa_pattern != other.n_swa_pattern) return false;
if (this->n_embd_head_k != other.n_embd_head_k) return true;
if (this->n_embd_head_v != other.n_embd_head_v) return true;
if (this->n_expert != other.n_expert) return true;
@@ -5797,6 +5814,17 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_COHERE2:
+ {
+ hparams.n_swa_pattern = 4;
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+ ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+ switch (hparams.n_layer) {
+ case 32: model.type = e_model::MODEL_8B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
@@ -6406,6 +6434,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
+ LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
@@ -8397,6 +8426,34 @@ static bool llm_load_tensors(
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
}
} break;
+ case LLM_ARCH_COHERE2:
+ {
+ model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+ // output
+ model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+ // init output from the input tok embed
+ model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
+ llama_model_loader::TENSOR_DUPLICATED);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = model.layers[i];
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+ layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
+ layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
+ layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
+ layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+
+ layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
+ layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+ layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
+ }
+ }
+ break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -9340,7 +9397,7 @@ static struct ggml_tensor * llm_build_kqv(
// For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG.
// Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
- (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8)) {
+ (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
//ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
@@ -9364,7 +9421,8 @@ static struct ggml_tensor * llm_build_kqv(
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
+ model.arch == LLM_ARCH_COHERE2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -9423,7 +9481,8 @@ static struct ggml_tensor * llm_build_kqv(
auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02);
auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12);
auto kq_i = ggml_mul_mat(ctx, k_i, q_i);
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
+ model.arch == LLM_ARCH_COHERE2) {
ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32);
}
if (model.arch == LLM_ARCH_GROK) {
@@ -15013,6 +15072,137 @@ struct llm_build_context {
return gf;
}
+ struct ggml_cgraph * build_cohere2() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ const float f_logit_scale = hparams.f_logit_scale;
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ // cohere2 requires different mask for layers using sliding window (SWA)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+ struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
+
+ // sliding window switch pattern
+ const int32_t sliding_window_pattern = 4;
+
+ for (int il = 0; il < n_layer; ++il) {
+ // three layers sliding window attention (window size 4096) and ROPE
+ // fourth layer uses global attention without positional embeddings
+ const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+ struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
+ cb(cur, "attn_norm", il);
+ struct ggml_tensor * ffn_inp = cur;
+
+ // self-attention
+ {
+ // rope freq factors for 128k context
+ struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ if (is_sliding) {
+ Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
+ beta_fast, beta_slow);
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+ rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+ attn_factor, beta_fast, beta_slow);
+ cb(Kcur, "Kcur", il);
+ } else {
+ // For non-sliding layers, just reshape without applying RoPE
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ cb(Kcur, "Kcur", il);
+ }
+
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
+ KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+ }
+
+ struct ggml_tensor * attn_out = cur;
+
+ // feed-forward network
+ {
+ cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
+ NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
+ cb, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ // add together residual + FFN + self-attention
+ cur = ggml_add(ctx0, cur, inpL);
+ cur = ggml_add(ctx0, cur, attn_out);
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+ if (f_logit_scale) {
+ cur = ggml_scale(ctx0, cur, f_logit_scale);
+ }
+
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
struct ggml_cgraph * build_t5_encoder() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -15813,6 +16003,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_bitnet_25();
} break;
+ case LLM_ARCH_COHERE2:
+ {
+ result = llm.build_cohere2();
+ } break;
case LLM_ARCH_T5:
{
if (lctx.is_encoding) {
@@ -19486,6 +19680,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
+ case LLM_ARCH_COHERE2:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2