Diffstat (limited to 'gguf-py')
-rw-r--r--   gguf-py/gguf/gguf.py | 26 ++++++++++++++++++++++++--
1 file changed, 24 insertions, 2 deletions
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index d377cd56..bda13ac0 100644
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -79,6 +79,7 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
 class MODEL_ARCH(IntEnum):
     LLAMA  : int = auto()
     FALCON : int = auto()
+    BAICHUAN:int = auto()
     GPT2   : int = auto()
     GPTJ   : int = auto()
     GPTNEOX: int = auto()
@@ -108,6 +109,7 @@ class MODEL_TENSOR(IntEnum):
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.LLAMA:   "llama",
     MODEL_ARCH.FALCON:  "falcon",
+    MODEL_ARCH.BAICHUAN:"baichuan",
     MODEL_ARCH.GPT2:    "gpt2",
     MODEL_ARCH.GPTJ:    "gptj",
     MODEL_ARCH.GPTNEOX: "gptneox",
@@ -153,6 +155,22 @@ MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
         MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn_down",
         MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn_up",
     },
+    MODEL_ARCH.BAICHUAN: {
+        MODEL_TENSOR.TOKEN_EMBD:    "token_embd",
+        MODEL_TENSOR.OUTPUT_NORM:   "output_norm",
+        MODEL_TENSOR.OUTPUT:        "output",
+        MODEL_TENSOR.ROPE_FREQS:    "rope_freqs",
+        MODEL_TENSOR.ATTN_NORM:     "blk.{bid}.attn_norm",
+        MODEL_TENSOR.ATTN_Q:        "blk.{bid}.attn_q",
+        MODEL_TENSOR.ATTN_K:        "blk.{bid}.attn_k",
+        MODEL_TENSOR.ATTN_V:        "blk.{bid}.attn_v",
+        MODEL_TENSOR.ATTN_OUT:      "blk.{bid}.attn_output",
+        MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+        MODEL_TENSOR.FFN_NORM:      "blk.{bid}.ffn_norm",
+        MODEL_TENSOR.FFN_GATE:      "blk.{bid}.ffn_gate",
+        MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn_down",
+        MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn_up",
+    },
     MODEL_ARCH.GPT2: {
         # TODO
     },
@@ -165,6 +183,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.BAICHUAN: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }


@@ -187,7 +209,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out", # gptneox
-            "lm_head",   # gpt2 mpt falcon llama-hf
+            "lm_head",   # gpt2 mpt falcon llama-hf baichuan
             "output",    # llama-pth
         ),

@@ -195,7 +217,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm", # gptneox
             "transformer.ln_f",          # gpt2 falcon
-            "model.norm",                # llama-hf
+            "model.norm",                # llama-hf baichuan
             "norm",                      # llama-pth
         ),
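
For context (not part of the commit itself): the new BAICHUAN entries are consumed the same way as the existing LLAMA ones. Per-block names in MODEL_TENSOR_NAMES are format strings with a {bid} (block id) placeholder, and MODEL_TENSOR_SKIP lists tensors that a converter should not emit for the architecture. Below is a minimal usage sketch, assuming the gguf-py directory is on the import path so that gguf-py/gguf/gguf.py is importable as gguf.gguf.

# Illustrative sketch only; assumes gguf-py is on sys.path.
from gguf.gguf import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSOR_NAMES, MODEL_TENSOR_SKIP

arch = MODEL_ARCH.BAICHUAN

# Per-block tensor names are templates with a {bid} placeholder.
print(MODEL_TENSOR_NAMES[arch][MODEL_TENSOR.ATTN_Q].format(bid=0))  # blk.0.attn_q
print(MODEL_TENSOR_NAMES[arch][MODEL_TENSOR.OUTPUT_NORM])           # output_norm

# ROPE_FREQS and ATTN_ROT_EMBD are listed in MODEL_TENSOR_SKIP for BAICHUAN,
# so a converter would not write them out for this architecture.
for tensor in (MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_Q):
    skip = tensor in MODEL_TENSOR_SKIP.get(arch, [])
    print(tensor.name, "-> skip" if skip else "-> keep")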