summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsaood06 <saood05@gmail.com>2025-05-09 02:09:59 -0500
committerGitHub <noreply@github.com>2025-05-09 10:09:59 +0300
commitbc6ae515ceb14eeaf198e00251a9689539cea176 (patch)
tree82ca4aa8afa0dce38e0f3cb6fa9c7a78ec0065d2
parent4084ca7331611da4426d781a15a6ffa68312759e (diff)
Support for Llama-3-Nemotron models (#377)
* conflict resolution * Changes to make work and add longrope support * Changes to n_attention_wv rule * Untested support of 253B * DeciLMCausalModel now reads rope_theta from config.json properly * Remove errant Granite mentions * Better n_attention_vw rule * Update vocab.py --------- Co-authored-by: Yee Man Chan <ymchan@gmail.com> Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rwxr-xr-xconvert_hf_to_gguf.py178
-rw-r--r--gguf-py/gguf/constants.py26
-rw-r--r--gguf-py/gguf/tensor_mapping.py3
-rw-r--r--gguf-py/gguf/vocab.py33
-rw-r--r--include/llama.h3
-rw-r--r--src/llama.cpp288
6 files changed, 523 insertions, 8 deletions
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 20d27a5c..16f97ab0 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1597,6 +1597,184 @@ class LlamaModel(Model):
raise ValueError(f"Unprocessed experts: {experts}")
+@Model.register("DeciLMForCausalLM")
+class DeciModel(Model):
+ model_arch = gguf.MODEL_ARCH.DECI
+
+ @staticmethod
+ def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
+ # DeciLM-specific code
+ intermediate_size = int(2 * ffn_mult * n_embd / 3)
+ return DeciModel._find_multiple(intermediate_size, 256)
+
+ @staticmethod
+ def _find_multiple(n: int, k: int) -> int:
+ # DeciLM-specific code
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B
+ _block_configs: list[dict[str,Any]] = self.hparams["block_configs"]
+ assert self.block_count == len(_block_configs)
+ self._num_kv_heads = list()
+ self._num_heads = list()
+ _ffn_multipliers = list()
+ # ***linear attention layer***
+ # if n_heads_in_group is None and replace_with_linear is True
+ # then _num_kv_heads[il] is 0 and _num_heads[il] is num_attention_heads
+ # ***attention-free layer***
+ # if n_heads_in_group is None and replace_with_linear is False
+ # then _num_kv_heads[il] is 0 and _num_heads[il] is 0
+ # ***normal attention-layer***
+ # if n_heads_in_group is not None, then
+ # _num_kv_heads[il] is num_attention_head // n_heads_in_group and
+ # _num_heads[il] is num_attention_head
+ # ***dummy layer*** for nemotron 253B
+ # if n_heads_in_group is None and ffn_mult is None
+ # then _num_kv_heads[il] is 0 and _num_heads[il] is 0 and _ffn_dims is 0
+ for il in range(len(_block_configs)):
+ if _block_configs[il]["attention"]["n_heads_in_group"] is None:
+ if _block_configs[il]["attention"]["replace_with_linear"] is True:
+ self._num_kv_heads.append(0)
+ self._num_heads.append(self.hparams["num_attention_heads"])
+ else:
+ self._num_kv_heads.append(0)
+ self._num_heads.append(0)
+ else:
+ self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"])
+ self._num_heads.append(self.hparams["num_attention_heads"])
+ if _block_configs[il]["ffn"]["ffn_mult"] is None: # dummy layer
+ _ffn_multipliers.append(0.0)
+ else:
+ _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"])
+ assert self.block_count == len(self._num_kv_heads)
+ assert self.block_count == len(self._num_heads)
+ assert self.block_count == len(_ffn_multipliers)
+ assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int)
+ assert isinstance(self._num_heads, list) and isinstance(self._num_heads[0], int)
+ assert isinstance(_ffn_multipliers, list) and isinstance(_ffn_multipliers[0], float)
+ self._ffn_dims: list[int] = [
+ DeciModel._ffn_mult_to_intermediate_size(multiplier, self.hparams["hidden_size"])
+ for multiplier in _ffn_multipliers
+ ]
+
+ def set_vocab(self):
+ # Please change tokenizer_config.json of Llama-3_1-Nemotron-51B's
+ # eos_token from '|eot_id|' to '|end_of_text|'
+ if self.hparams.get("vocab_size", 128256) == 128256:
+ tokens, toktypes, tokpre = self.get_vocab_base()
+ self.gguf_writer.add_tokenizer_model("gpt2")
+ self.gguf_writer.add_tokenizer_pre(tokpre)
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_types(toktypes)
+
+ special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+ special_vocab.add_to_gguf(self.gguf_writer)
+ else:
+ # DeciLM-7B
+ self._set_vocab_llama_hf()
+
+ def set_gguf_parameters(self):
+ if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B
+ assert self.block_count == len(self._num_kv_heads)
+ assert self.block_count == len(self._num_heads)
+ assert self.block_count == len(self._ffn_dims)
+ if (rope_theta := self.hparams.get("rope_theta")) is not None:
+ self.gguf_writer.add_rope_freq_base(rope_theta)
+ self.gguf_writer.add_head_count_kv(self._num_kv_heads)
+ self.gguf_writer.add_head_count(self._num_heads)
+ self.gguf_writer.add_feed_forward_length(self._ffn_dims)
+ self.gguf_writer.add_block_count(self.block_count)
+ self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+ self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+ self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+ self.gguf_writer.add_key_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+ self.gguf_writer.add_value_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+ self.gguf_writer.add_file_type(self.ftype)
+ else: # DeciLM-7B
+ super().set_gguf_parameters()
+ if "num_key_value_heads_per_layer" in self.hparams: # DeciLM-7B
+ self._num_kv_heads: list[int] = self.hparams["num_key_value_heads_per_layer"]
+ assert self.block_count == len(self._num_kv_heads)
+ self.gguf_writer.add_head_count_kv(self._num_kv_heads)
+ hparams = self.hparams
+ self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+ if "head_dim" in hparams:
+ rope_dim = hparams["head_dim"]
+ else:
+ rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+ self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+ if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+ if self.hparams["rope_scaling"].get("type") == "linear":
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+ self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+
+ @staticmethod
+ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
+ if n_head_kv is not None and n_head != n_head_kv:
+ n_head = n_head_kv
+ return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+ .swapaxes(1, 2)
+ .reshape(weights.shape))
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ n_head = self.hparams["num_attention_heads"]
+ if bid is not None:
+ if "num_key_value_heads_per_layer" in self.hparams:
+ n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid]
+ elif "block_configs" in self.hparams:
+ n_kv_head = self._num_kv_heads[bid]
+ n_head = self._num_heads[bid]
+ else:
+ n_kv_head = self.hparams.get("num_key_value_heads")
+ else:
+ n_kv_head = self.hparams.get("num_key_value_heads")
+
+ if name.endswith(("q_proj.weight", "q_proj.bias")):
+ data_torch = DeciModel.permute(data_torch, n_head, n_head)
+ if name.endswith(("k_proj.weight", "k_proj.bias")):
+ data_torch = DeciModel.permute(data_torch, n_head, n_kv_head)
+ return [(self.map_tensor_name(name), data_torch)]
+
+ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+ if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+ if rope_scaling.get("rope_type", '').lower() == "llama3":
+ base = self.hparams.get("rope_theta", 10000.0)
+ dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+ freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+ factor = rope_scaling.get("factor", 8.0)
+ low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+ high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+ old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
+
+ low_freq_wavelen = old_context_len / low_freq_factor
+ high_freq_wavelen = old_context_len / high_freq_factor
+ assert low_freq_wavelen != high_freq_wavelen
+
+ rope_factors = []
+ for freq in freqs:
+ wavelen = 2 * math.pi / freq
+ if wavelen < high_freq_wavelen:
+ rope_factors.append(1)
+ elif wavelen > low_freq_wavelen:
+ rope_factors.append(factor)
+ else:
+ smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+ rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+ yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
+
+ def prepare_tensors(self):
+ super().prepare_tensors()
+
+
@Model.register("BitnetForCausalLM")
@Model.register("BitNetForCausalLM")
class BitnetModel(Model):
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index d6feb9ea..e2f4eb1a 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -181,6 +181,7 @@ class GGUFType:
class MODEL_ARCH(IntEnum):
LLAMA = auto()
+ DECI = auto()
FALCON = auto()
BAICHUAN = auto()
GROK = auto()
@@ -316,6 +317,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
+ MODEL_ARCH.DECI: "deci",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan",
MODEL_ARCH.GROK: "grok",
@@ -470,6 +472,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
+ MODEL_ARCH.DECI: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
MODEL_ARCH.GROK: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -1149,6 +1171,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
+ MODEL_ARCH.DECI: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
MODEL_ARCH.BAICHUAN: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 1dea6a82..3ff70cd7 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -176,7 +176,8 @@ class TensorNameMap:
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
- "model.layers.{bid}.self_attn.o_proj", # llama-hf
+ "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
+ "model.layers.{bid}.self_attn.linear_attn", # deci
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py
index dc574991..cca09798 100644
--- a/gguf-py/gguf/vocab.py
+++ b/gguf-py/gguf/vocab.py
@@ -122,8 +122,30 @@ class SpecialVocab:
tokenizer = json.load(f)
if self.load_merges:
merges = tokenizer.get('model', {}).get('merges')
- if isinstance(merges, list) and merges and isinstance(merges[0], str):
- self.merges = merges
+ if isinstance(merges, list) and merges:
+ if isinstance(merges[0], str):
+ self.merges = merges
+ elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
+ # New format since transformers 4.45 to support spaces in merges
+ # ref: https://github.com/ggml-org/llama.cpp/issues/9692
+ # TODO: internally store as the new format instead of converting to old
+ if any(' ' in s for pair in merges for s in pair):
+ logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
+ self.merges = [
+ ' '.join(
+ [
+ # ensure the spaces are properly encoded
+ ''.join(
+ chr(ord(c) + 256) if c == ' ' else c
+ for c in part
+ )
+ for part in pair
+ ]
+ )
+ for pair in merges
+ ]
+ else:
+ raise ValueError("Unknown tokenizer merges format")
added_tokens = tokenizer.get('added_tokens', {})
else:
added_tokens = {}
@@ -132,7 +154,12 @@ class SpecialVocab:
return True
with open(tokenizer_config_file, encoding = 'utf-8') as f:
tokenizer_config = json.load(f)
- chat_template = tokenizer_config.get('chat_template')
+ chat_template_alt = None
+ chat_template_file = path / 'chat_template.json'
+ if chat_template_file.is_file():
+ with open(chat_template_file, encoding = 'utf-8') as f:
+ chat_template_alt = json.load(f).get('chat_template')
+ chat_template = tokenizer_config.get('chat_template', chat_template_alt)
if chat_template is None or isinstance(chat_template, (str, list)):
self.chat_template = chat_template
else:
diff --git a/include/llama.h b/include/llama.h
index d7376d7d..e2901861 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -230,7 +230,8 @@ extern "C" {
LLAMA_ROPE_SCALING_TYPE_NONE = 0,
LLAMA_ROPE_SCALING_TYPE_LINEAR = 1,
LLAMA_ROPE_SCALING_TYPE_YARN = 2,
- LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
+ LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3,
+ LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE,
};
enum llama_pooling_type {
diff --git a/src/llama.cpp b/src/llama.cpp
index dd9d82dc..9fbb58d1 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -185,6 +185,7 @@ static std::string format(const char * fmt, ...) {
enum llm_arch {
LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4,
+ LLM_ARCH_DECI,
LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN,
LLM_ARCH_GROK,
@@ -240,6 +241,7 @@ enum llm_arch {
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_LLAMA4, "llama4" },
+ { LLM_ARCH_DECI, "deci" },
{ LLM_ARCH_FALCON, "falcon" },
{ LLM_ARCH_GROK, "grok" },
{ LLM_ARCH_GPT2, "gpt2" },
@@ -282,7 +284,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_BITNET_25, "bitnet-25" },
- { LLM_ARCH_BITNET_B158, "bitnet-b1.58" },
+ { LLM_ARCH_BITNET_B158, "bitnet-b1.58" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" },
@@ -633,6 +635,32 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
+ LLM_ARCH_DECI,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+ { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+ { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ },
+ },
+ {
LLM_ARCH_LLAMA4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -2525,6 +2553,7 @@ enum e_model {
MODEL_70B,
MODEL_236B,
MODEL_314B,
+ MODEL_405B,
MODEL_671B,
MODEL_SMALL,
MODEL_MEDIUM,
@@ -5128,6 +5157,7 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_70B: return "70B";
case MODEL_236B: return "236B";
case MODEL_314B: return "314B";
+ case MODEL_405B: return "405B";
case MODEL_671B: return "671B";
case MODEL_SMALL: return "0.1B";
case MODEL_MEDIUM: return "0.4B";
@@ -5265,7 +5295,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
- if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25 || model.arch == LLM_ARCH_BITNET_B158) {
+ if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25 || model.arch == LLM_ARCH_BITNET_B158 || model.arch == LLM_ARCH_DECI) {
if (hparams.n_rot != hparams.n_embd_head_k) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
}
@@ -5320,6 +5350,16 @@ static void llm_load_hparams(
if (model.type == MODEL_17B_128E) {
hparams.use_kq_norm = false;
+ }
+ } break;
+ case LLM_ARCH_DECI:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ switch (hparams.n_layer) {
+ case 32: model.type = e_model::MODEL_7B; break;
+ case 80: model.type = e_model::MODEL_70B; break;
+ case 162: model.type = e_model::MODEL_405B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_MINICPM:
@@ -6956,6 +6996,76 @@ static bool llm_load_tensors(
}
}
} break;
+ case LLM_ARCH_DECI:
+ {
+ model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+ model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+ // if output is NULL, init from the input tok embed
+ if (model.output == NULL) {
+ model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
+ const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i);
+ const int64_t n_ff = hparams.n_ff(i);
+ const int64_t n_head = hparams.n_head(i);
+ const int64_t n_head_kv = hparams.n_head_kv(i);
+
+ if (n_head_kv == 0 && n_head > 0) {
+ // linear attention for DeciLMCausalModel
+ layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+ layer.wo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+ }
+ else if (n_head_kv > 0) {
+ layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+ layer.wq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
+ layer.wk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
+ layer.wv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
+ layer.wo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+ }
+
+ // optional bias tensors
+
+
+ layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ if (n_ff > 0) {
+ layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+ }
+
+ if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
+ layer.rope_long = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_rot/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+ layer.rope_short = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_rot/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+ }
+ else {
+ layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+ }
+
+ if (n_ff > 0) {
+ layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+ layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
+ layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+ }
+
+ // optional MLP bias
+ layer.ffn_gate_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.ffn_down_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.ffn_up_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ }
+ } break;
case LLM_ARCH_LLAMA4:
{
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -10485,6 +10595,168 @@ struct llm_build_context {
return gf;
}
+ struct ggml_cgraph * build_deci() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+ const int64_t n_head_kv = hparams.n_head_kv(il);
+ const int64_t n_head = hparams.n_head(il);
+ const int64_t n_ff = hparams.n_ff(il);
+
+ if (n_head == 0) { // attention-free layer of Llama-3_1-Nemotron-51B
+ cur = inpL;
+ } else {
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+ }
+
+ if (n_head > 0 && n_head_kv == 0) { // "linear attention" of Llama-3_1-Nemotron-51B
+ cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+ cb(cur, "wo", il);
+ } else if (n_head > 0) {
+ // self-attention
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
+ struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ // FFN-free layer of Llama-3_1-Nemotron-Ultra-253B
+ if (n_ff == 0) {
+ continue;
+ }
+
+ if (hparams.f_residual_scale) {
+ cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+ }
+
+ // modified to support attention-free layer of Llama-3_1-Nemotron-51B
+ struct ggml_tensor * ffn_inp = cur;
+ if (n_head > 0) {
+ ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+ }
+
+ // feed-forward network
+ if (model.layers[il].ffn_gate_inp == nullptr) {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, lctx, cur,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
+ model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ if (hparams.f_residual_scale) {
+ cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "ffn_out", il);
+
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+ if (hparams.f_logit_scale) {
+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+ }
+
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
struct ggml_cgraph * build_baichuan() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -16477,6 +16749,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_llama();
} break;
+ case LLM_ARCH_DECI:
+ {
+ result = llm.build_deci();
+ } break;
case LLM_ARCH_BAICHUAN:
{
result = llm.build_baichuan();
@@ -18997,8 +19273,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// - qs.n_attention_wv == 0 for Mamba models
// - qs.n_attention_wv == model.hparams.n_layer for Transformer models
// - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
+ // - model.arch == LLM_ARCH_DECI for Deci-Nemotron models
//
- GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
+ GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected");
size_t total_size_org = 0;
size_t total_size_new = 0;
@@ -20298,6 +20575,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
// use what we call a normal RoPE, operating on pairs of consecutive head values
case LLM_ARCH_LLAMA:
+ case LLM_ARCH_DECI:
case LLM_ARCH_LLAMA4:
case LLM_ARCH_BAICHUAN:
case LLM_ARCH_STARCODER:
@@ -20371,6 +20649,10 @@ int32_t llama_n_layer(const struct llama_model * model) {
return model->hparams.n_layer;
}
+int32_t llama_n_head(const struct llama_model * model) {
+ return model->hparams.n_head();
+}
+
float llama_rope_freq_scale_train(const struct llama_model * model) {
return model->hparams.rope_freq_scale_train;
}