diff options
-rw-r--r-- | convert_hf_to_gguf.py | 28 | ||||
-rw-r--r-- | gguf-py/gguf/constants.py | 26 | ||||
-rw-r--r-- | gguf-py/gguf/tensor_mapping.py | 2 | ||||
-rw-r--r-- | src/llama.cpp | 273 |
4 files changed, 326 insertions, 3 deletions
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 33be63fa..b0a82c80 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3864,6 +3864,34 @@ class JaisModel(Model): self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias) +@Model.register("Dots1ForCausalLM") +class Dots1Model(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.DOTS1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hparams["num_experts"] = self.hparams["n_routed_experts"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_leading_dense_block_count(self.hparams["first_k_dense_replace"]) + self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"]) + + if self.hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + else: + raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}") + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + if "shared_experts" in name: + return [(self.map_tensor_name(name), data_torch)] + return super().modify_tensors(data_torch, name, bid) + + @Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(Model): model_arch = gguf.MODEL_ARCH.CHATGLM diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 489714c4..b3b2bc50 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -226,6 +226,7 @@ class MODEL_ARCH(IntEnum): T5 = auto() T5ENCODER = auto() JAIS = auto() + DOTS1 = auto() class MODEL_TENSOR(IntEnum): @@ -362,6 +363,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.DOTS1: "dots1", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1164,6 +1166,30 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.DOTS1: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 9688b02c..d507725c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -257,7 +257,7 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_EXP_PROBS_B: ( - "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 + "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1 ), # Feed-forward up diff --git a/src/llama.cpp b/src/llama.cpp index 92403f6a..8e6c66d3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -235,6 +235,7 @@ enum llm_arch { LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, LLM_ARCH_COHERE2, + LLM_ARCH_DOTS1, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_UNKNOWN, }; @@ -292,6 +293,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_COHERE2, "cohere2" }, + { LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1598,6 +1600,34 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA }, }, { + LLM_ARCH_DOTS1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + } + }, + { LLM_ARCH_HUNYUAN_MOE, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, @@ -1663,6 +1693,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_BITNET, + LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_UNKNOWN, }; @@ -2580,6 +2611,7 @@ enum e_model { MODEL_40B, MODEL_65B, MODEL_70B, + MODEL_142B, MODEL_236B, MODEL_314B, MODEL_405B, @@ -5214,6 +5246,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_40B: return "40B"; case MODEL_65B: return "65B"; case MODEL_70B: return "70B"; + case MODEL_142B: return "142B"; case MODEL_236B: return "236B"; case MODEL_314B: return "314B"; case MODEL_405B: return "405B"; @@ -6066,6 +6099,20 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_DOTS1: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + switch (hparams.n_layer) { + case 62: model.type = e_model::MODEL_142B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_HUNYUAN_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -6170,7 +6217,12 @@ static void llm_load_vocab( } // default special tokens - vocab.special_bos_id = 11; + if(model.arch == LLM_ARCH_DOTS1) { + vocab.special_bos_id = -1; + } + else { + vocab.special_bos_id = 11; + } vocab.special_eos_id = 11; vocab.special_unk_id = -1; vocab.special_sep_id = -1; @@ -9208,6 +9260,54 @@ static bool llm_load_tensors( layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_DOTS1: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + // MoE branch + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; case LLM_ARCH_HUNYUAN_MOE: { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -16948,6 +17048,153 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_dots1() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (enum llm_expert_gating_func_type) hparams.expert_gating_func, + cb, il); + cb(moe_out, "ffn_moe_out", il); + + { + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_hunyuan_moe() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -17198,7 +17445,7 @@ static struct ggml_cgraph * llama_build_graph( const llama_vocab * vocab = llama_get_vocab(&lctx); llama_token bos = llama_token_bos_impl(*vocab); llama_token eos = llama_token_eos_impl(*vocab); - bool is_warming_up = (batch.n_tokens == 1 && batch.token[0] == bos); + bool is_warming_up = (batch.n_tokens == 1 && (batch.token[0] == ((bos != -1) ? bos : eos))); struct llm_build_context llm(lctx, batch, cb, worst_case, is_warming_up); llm.init(); @@ -17394,6 +17641,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_jais(); } break; + case LLM_ARCH_DOTS1: + { + result = llm.build_dots1(); + } break; case LLM_ARCH_HUNYUAN_MOE: { result = llm.build_hunyuan_moe(); @@ -21170,6 +21421,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: case LLM_ARCH_CODESHELL: + case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: return LLAMA_ROPE_TYPE_NEOX; @@ -22984,6 +23236,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_MEGREZ; } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { return LLM_CHAT_TEMPLATE_LLAMA4; + } else if (tmpl_contains("<|endofuserprompt|>")) { + return LLM_CHAT_TEMPLATE_DOTS1; } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; } @@ -23404,6 +23658,21 @@ static int32_t llama_chat_apply_template_internal( ss << message->content; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) { + // dots.llm1.inst (DOTS1) + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|system|>" << message->content << "<|endofsystem|>"; + } else if (role == "user") { + ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>"; + } else { + ss << "<|response|>" << message->content << "<|endofresponse|>"; + } + } + if (add_ass) { + ss << "<|response|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) { // tencent/Hunyuan-A13B-Instruct for (auto message : chat) { |