summaryrefslogtreecommitdiff
path: root/gguf-py
diff options
context:
space:
mode:
Diffstat (limited to 'gguf-py')
-rw-r--r--gguf-py/gguf/gguf.py26
1 files changed, 24 insertions, 2 deletions
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index d377cd56..bda13ac0 100644
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -79,6 +79,7 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
class MODEL_ARCH(IntEnum):
LLAMA : int = auto()
FALCON : int = auto()
+ BAICHUAN:int = auto()
GPT2 : int = auto()
GPTJ : int = auto()
GPTNEOX: int = auto()
@@ -108,6 +109,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
+ MODEL_ARCH.BAICHUAN:"baichuan",
MODEL_ARCH.GPT2: "gpt2",
MODEL_ARCH.GPTJ: "gptj",
MODEL_ARCH.GPTNEOX: "gptneox",
@@ -153,6 +155,22 @@ MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
},
+ MODEL_ARCH.BAICHUAN: {
+ MODEL_TENSOR.TOKEN_EMBD: "token_embd",
+ MODEL_TENSOR.OUTPUT_NORM: "output_norm",
+ MODEL_TENSOR.OUTPUT: "output",
+ MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
+ MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
+ MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
+ MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
+ MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
+ MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
+ MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+ MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
+ MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
+ MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
+ MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
+ },
MODEL_ARCH.GPT2: {
# TODO
},
@@ -165,6 +183,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
+ MODEL_ARCH.BAICHUAN: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
}
@@ -187,7 +209,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
- "lm_head", # gpt2 mpt falcon llama-hf
+ "lm_head", # gpt2 mpt falcon llama-hf baichuan
"output", # llama-pth
),
@@ -195,7 +217,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 falcon
- "model.norm", # llama-hf
+ "model.norm", # llama-hf baichuan
"norm", # llama-pth
),