Diffstat (limited to 'gguf-py/gguf')
-rw-r--r--  gguf-py/gguf/constants.py       41
-rw-r--r--  gguf-py/gguf/gguf_writer.py     12
-rw-r--r--  gguf-py/gguf/tensor_mapping.py  46
3 files changed, 97 insertions, 2 deletions
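In short: constants.py gains the arch-templated "{arch}.ssm.*" metadata keys, a MAMBA entry in the architecture enum, and GGUF names for the seven SSM tensors; gguf_writer.py gains uint32 setters for the four SSM hyperparameters; tensor_mapping.py gains name mappings from the known Mamba checkpoint layouts (model.*, backbone.*, and the backbone.embeddings HF variant).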
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index a6213981..b23badb1 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -61,6 +61,12 @@ class Keys:
         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
 
+    class SSM:
+        CONV_KERNEL    = "{arch}.ssm.conv_kernel"
+        INNER_SIZE     = "{arch}.ssm.inner_size"
+        STATE_SIZE     = "{arch}.ssm.state_size"
+        TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
+
     class Tokenizer:
         MODEL      = "tokenizer.ggml.model"
         LIST       = "tokenizer.ggml.tokens"
@@ -113,6 +119,7 @@ class MODEL_ARCH(IntEnum):
     MINICPM    = auto()
     GEMMA      = auto()
     STARCODER2 = auto()
+    MAMBA      = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -144,6 +151,13 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_NORM    = auto()
     ATTN_K_NORM    = auto()
     LAYER_OUT_NORM = auto()
+    SSM_IN         = auto()
+    SSM_CONV1D     = auto()
+    SSM_X          = auto()
+    SSM_DT         = auto()
+    SSM_A          = auto()
+    SSM_D          = auto()
+    SSM_OUT        = auto()
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -171,6 +185,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.MINICPM:        "minicpm",
     MODEL_ARCH.GEMMA:          "gemma",
     MODEL_ARCH.STARCODER2:     "starcoder2",
+    MODEL_ARCH.MAMBA:          "mamba",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -202,6 +217,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_DOWN_EXP:   "blk.{bid}.ffn_down.{xid}",
     MODEL_TENSOR.FFN_UP_EXP:     "blk.{bid}.ffn_up.{xid}",
     MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
+    MODEL_TENSOR.SSM_IN:         "blk.{bid}.ssm_in",
+    MODEL_TENSOR.SSM_CONV1D:     "blk.{bid}.ssm_conv1d",
+    MODEL_TENSOR.SSM_X:          "blk.{bid}.ssm_x",
+    MODEL_TENSOR.SSM_DT:         "blk.{bid}.ssm_dt",
+    MODEL_TENSOR.SSM_A:          "blk.{bid}.ssm_a",
+    MODEL_TENSOR.SSM_D:          "blk.{bid}.ssm_d",
+    MODEL_TENSOR.SSM_OUT:        "blk.{bid}.ssm_out",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -543,6 +565,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.MAMBA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_X,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_OUT,
+    ],
     # TODO
 }
@@ -734,6 +769,12 @@ KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
 KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
 KEY_ROPE_SCALING_FINETUNED    = Keys.Rope.SCALING_FINETUNED
 
+# SSM
+KEY_SSM_CONV_KERNEL    = Keys.SSM.CONV_KERNEL
+KEY_SSM_INNER_SIZE     = Keys.SSM.INNER_SIZE
+KEY_SSM_STATE_SIZE     = Keys.SSM.STATE_SIZE
+KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
+
 # tokenization
 KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
 KEY_TOKENIZER_LIST  = Keys.Tokenizer.LIST
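The constants.py side is now complete: Keys.SSM defines the metadata key templates, MODEL_ARCH.MAMBA registers the architecture, and TENSOR_NAMES fixes the on-disk tensor names. A minimal sketch of how the arch-templated keys resolve, assuming the gguf package re-exports Keys as usual (the "mamba" string is the MODEL_ARCH_NAMES entry above):

# sketch only: resolve the SSM key templates for the mamba arch
from gguf import Keys

print(Keys.SSM.CONV_KERNEL.format(arch="mamba"))     # -> "mamba.ssm.conv_kernel"
print(Keys.SSM.TIME_STEP_RANK.format(arch="mamba"))  # -> "mamba.ssm.time_step_rank"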
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 80116083..e49c5db6 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -382,6 +382,18 @@ class GGUFWriter:
     def add_rope_scaling_finetuned(self, value: bool) -> None:
         self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
 
+    def add_ssm_conv_kernel(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
+
+    def add_ssm_inner_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_state_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_time_step_rank(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
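Each new setter stores its hyperparameter as a uint32 under the matching Keys.SSM template, formatted with the writer's arch. A minimal usage sketch for a conversion script, assuming GGUFWriter's usual (path, arch) constructor and its existing write_*_to_file()/close() methods, which this diff does not touch; the values are common Mamba defaults (d_conv=4, d_state=16, d_inner typically 2*d_model, dt_rank typically ceil(d_model/16)), not taken from any particular checkpoint:

# sketch only: writing SSM metadata while converting a hypothetical model
from gguf import GGUFWriter

writer = GGUFWriter("mamba-example.gguf", "mamba")  # hypothetical output path
writer.add_ssm_conv_kernel(4)       # -> "mamba.ssm.conv_kernel"
writer.add_ssm_inner_size(1536)     # -> "mamba.ssm.inner_size"
writer.add_ssm_state_size(16)       # -> "mamba.ssm.state_size"
writer.add_ssm_time_step_rank(48)   # -> "mamba.ssm.time_step_rank"
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()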
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index db2ec970..ed89955d 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -20,6 +20,9 @@ class TensorNameMap:
"wte", # gpt2
"transformer.embd.wte", # phi2
"model.tok_embeddings", # internlm2
+ "model.embedding", # mamba-qbert
+ "backbone.embedding", # mamba
+ "backbone.embeddings", # mamba-hf
),
# Token type embeddings
@@ -44,7 +47,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out",                 # gptneox
-            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen
+            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
@@ -61,6 +64,8 @@ class TensorNameMap:
"language_model.encoder.final_layernorm", # persimmon
"model.final_layernorm", # persimmon
"lm_head.ln", # phi2
+ "model.norm_f", # mamba-qbert
+ "backbone.norm_f", # mamba
),
# Rope frequencies
@@ -86,6 +91,8 @@ class TensorNameMap:
"transformer.h.{bid}.ln", # phi2
"model.layers.layers.{bid}.norm", # plamo
"model.layers.{bid}.attention_norm", # internlm2
+ "model.layers.{bid}.norm", # mamba-qbert
+ "backbone.layers.{bid}.norm", # mamba
),
# Attention norm 2
@@ -282,7 +289,42 @@ class TensorNameMap:
         MODEL_TENSOR.LAYER_OUT_NORM: (
             "encoder.layer.{bid}.output.LayerNorm",  # bert
             "encoder.layers.{bid}.norm2",            # nomic-bert
-        )
+        ),
+
+        MODEL_TENSOR.SSM_IN: (
+            "model.layers.{bid}.in_proj",
+            "backbone.layers.{bid}.mixer.in_proj",
+        ),
+
+        MODEL_TENSOR.SSM_CONV1D: (
+            "model.layers.{bid}.conv1d",
+            "backbone.layers.{bid}.mixer.conv1d",
+        ),
+
+        MODEL_TENSOR.SSM_X: (
+            "model.layers.{bid}.x_proj",
+            "backbone.layers.{bid}.mixer.x_proj",
+        ),
+
+        MODEL_TENSOR.SSM_DT: (
+            "model.layers.{bid}.dt_proj",
+            "backbone.layers.{bid}.mixer.dt_proj",
+        ),
+
+        MODEL_TENSOR.SSM_A: (
+            "model.layers.{bid}.A_log",
+            "backbone.layers.{bid}.mixer.A_log",
+        ),
+
+        MODEL_TENSOR.SSM_D: (
+            "model.layers.{bid}.D",
+            "backbone.layers.{bid}.mixer.D",
+        ),
+
+        MODEL_TENSOR.SSM_OUT: (
+            "model.layers.{bid}.out_proj",
+            "backbone.layers.{bid}.mixer.out_proj",
+        ),
     }
 
     mapping: dict[str, tuple[MODEL_TENSOR, str]]
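Together with MODEL_TENSORS[MODEL_ARCH.MAMBA] from constants.py, these entries let TensorNameMap translate any of the Mamba checkpoint layouts to the blk.{bid}.ssm_* names. A minimal sketch, assuming the module's existing get_tensor_name_map() helper and TensorNameMap.get_name() (both untouched by this diff); the 24-block count is illustrative:

# sketch only: map a Hugging Face style Mamba tensor name to its GGUF name
from gguf import MODEL_ARCH, get_tensor_name_map

tmap = get_tensor_name_map(MODEL_ARCH.MAMBA, 24)

# matched against the SSM_IN patterns above -> "blk.0.ssm_in.weight"
print(tmap.get_name("backbone.layers.0.mixer.in_proj.weight",
                    try_suffixes=(".weight", ".bias")))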