author | Julius Arkenberg <arki05@users.noreply.github.com> | 2024-03-23 17:41:53 +0100
---|---|---
committer | GitHub <noreply@github.com> | 2024-03-23 18:41:53 +0200
commit | 476b0251b27fb64c575507024a671e639d675594 (patch) |
tree | 444579ca4cde7158fb79768bba055dcfae3cdb53 /gguf-py/gguf/tensor_mapping.py |
parent | 21cad01b6e6e1a96f99391f95e8ea8ae25c8288e (diff) |
llama : add grok-1 support (#6204)
* Add support for Grok model architecture
* Revert convert-hf-to-gguf to default options
* Fixed f_norm_rms_eps bug
* Fix whitespaces
* llama : fix grok rope type
* llama : minor
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'gguf-py/gguf/tensor_mapping.py')
-rw-r--r-- | gguf-py/gguf/tensor_mapping.py | 55
1 file changed, 35 insertions, 20 deletions
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index ed89955d..11fd34b8 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -23,6 +23,7 @@ class TensorNameMap:
             "model.embedding",          # mamba-qbert
             "backbone.embedding",       # mamba
             "backbone.embeddings",      # mamba-hf
+            "transformer.in_out_embed", # Grok
         ),

         # Token type embeddings
@@ -66,6 +67,7 @@ class TensorNameMap:
             "lm_head.ln",           # phi2
             "model.norm_f",         # mamba-qbert
             "backbone.norm_f",      # mamba
+            "transformer.rms_norm", # Grok
         ),

         # Rope frequencies
@@ -93,6 +95,7 @@ class TensorNameMap:
             "model.layers.{bid}.attention_norm",        # internlm2
             "model.layers.{bid}.norm",                  # mamba-qbert
             "backbone.layers.{bid}.norm",               # mamba
+            "transformer.decoder_layer.{bid}.rms_norm", # Grok
         ),

         # Attention norm 2
@@ -116,32 +119,35 @@ class TensorNameMap:

         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",        # llama-hf
-            "layers.{bid}.attention.wq",                  # llama-pth
-            "encoder.layer.{bid}.attention.self.query",   # bert
-            "transformer.h.{bid}.attn.q_proj",            # gpt-j
-            "model.layers.layers.{bid}.self_attn.q_proj", # plamo
-            "model.layers.{bid}.attention.wq"             # internlm2
+            "model.layers.{bid}.self_attn.q_proj",                       # llama-hf
+            "layers.{bid}.attention.wq",                                 # llama-pth
+            "encoder.layer.{bid}.attention.self.query",                  # bert
+            "transformer.h.{bid}.attn.q_proj",                           # gpt-j
+            "model.layers.layers.{bid}.self_attn.q_proj",                # plamo
+            "model.layers.{bid}.attention.wq",                           # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
         ),

         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",        # llama-hf
-            "layers.{bid}.attention.wk",                  # llama-pth
-            "encoder.layer.{bid}.attention.self.key",     # bert
-            "transformer.h.{bid}.attn.k_proj",            # gpt-j
-            "model.layers.layers.{bid}.self_attn.k_proj", # plamo
-            "model.layers.{bid}.attention.wk"             # internlm2
+            "model.layers.{bid}.self_attn.k_proj",                     # llama-hf
+            "layers.{bid}.attention.wk",                               # llama-pth
+            "encoder.layer.{bid}.attention.self.key",                  # bert
+            "transformer.h.{bid}.attn.k_proj",                         # gpt-j
+            "model.layers.layers.{bid}.self_attn.k_proj",              # plamo
+            "model.layers.{bid}.attention.wk",                         # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
         ),

         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",        # llama-hf
-            "layers.{bid}.attention.wv",                  # llama-pth
-            "encoder.layer.{bid}.attention.self.value",   # bert
-            "transformer.h.{bid}.attn.v_proj",            # gpt-j
-            "model.layers.layers.{bid}.self_attn.v_proj", # plamo
-            "model.layers.{bid}.attention.wv"             # internlm2
+            "model.layers.{bid}.self_attn.v_proj",                       # llama-hf
+            "layers.{bid}.attention.wv",                                 # llama-pth
+            "encoder.layer.{bid}.attention.self.value",                  # bert
+            "transformer.h.{bid}.attn.v_proj",                           # gpt-j
+            "model.layers.layers.{bid}.self_attn.v_proj",                # plamo
+            "model.layers.{bid}.attention.wv",                           # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
         ),

         # Attention output
@@ -162,12 +168,14 @@ class TensorNameMap:
             "model.layers.layers.{bid}.self_attn.o_proj",                 # plamo
             "model.layers.{bid}.attention.wo",                            # internlm2
             "encoder.layers.{bid}.attn.out_proj",                         # nomic-bert
+            "transformer.decoder_layer.{bid}.multi_head_attention.linear" # Grok
         ),

         # Attention output norm
         MODEL_TENSOR.ATTN_OUT_NORM: (
             "encoder.layer.{bid}.attention.output.LayerNorm", # bert
             "encoder.layers.{bid}.norm1",                     # nomic-bert
+            "transformer.decoder_layer.{bid}.rms_norm_1",     # Grok
         ),

         # Rotary embeddings
@@ -190,11 +198,13 @@ class TensorNameMap:
             "model.layers.{bid}.ln2",                     # yi
             "h.{bid}.ln_2",                               # gpt2
             "model.layers.{bid}.ffn_norm",                # internlm2
+            "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
         ),

         MODEL_TENSOR.FFN_GATE_INP: (
             "layers.{bid}.feed_forward.gate",           # mixtral
             "model.layers.{bid}.block_sparse_moe.gate", # mixtral
+            "transformer.decoder_layer.{bid}.router"    # Grok
         ),

         # Feed-forward up
@@ -223,6 +233,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_UP_EXP: (
             "layers.{bid}.feed_forward.experts.{xid}.w3",           # mixtral
             "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear_v",   # Grok
         ),

         # AWQ-activation gate
@@ -243,6 +254,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_GATE_EXP: (
             "layers.{bid}.feed_forward.experts.{xid}.w1",           # mixtral
             "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear"      # Grok
         ),

         # Feed-forward down
@@ -270,6 +282,8 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_DOWN_EXP: (
             "layers.{bid}.feed_forward.experts.{xid}.w2",           # mixtral
             "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear_1",   # Grok
+
         ),

         MODEL_TENSOR.ATTN_Q_NORM: (
@@ -287,8 +301,9 @@ class TensorNameMap:
         ),

         MODEL_TENSOR.LAYER_OUT_NORM: (
-            "encoder.layer.{bid}.output.LayerNorm", # bert
-            "encoder.layers.{bid}.norm2",           # nomic-bert
+            "encoder.layer.{bid}.output.LayerNorm",       # bert
+            "encoder.layers.{bid}.norm2",                 # nomic-bert
+            "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
         ),

         MODEL_TENSOR.SSM_IN: (