author     Pierrick Hymbert <pierrick.hymbert@gmail.com>    2024-04-13 11:33:52 +0200
committer  GitHub <noreply@github.com>                      2024-04-13 11:33:52 +0200
commit     4bd0f93e4ab4fe6682e7d0241c1bdec1397e954a (patch)
tree       da912ccbf957473fb5aa6868c9cd73f0fcc42e63 /gguf-py
parent     ab9a3240a9da941fdef5cd4a25f2b97c2f5a67aa (diff)
model: support arch `DbrxForCausalLM` (#6515)
* model: dbrx convert to gguf #6344
* llama: support dbrx #6344
* doc: dbrx: add the model as supported
* scripts: get-wikitext-2 add unzip
* llama: increase maximum experts allowed
* llama: factorize moe graph implementation between grok, mixtral and dbrx

Co-authored-by: Megha Agarwal <16129366+megha95@users.noreply.github.com>
Diffstat (limited to 'gguf-py')
-rw-r--r--  gguf-py/gguf/constants.py       15
-rw-r--r--  gguf-py/gguf/tensor_mapping.py  58
2 files changed, 48 insertions, 25 deletions
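For orientation (not part of this commit), a minimal Python sketch of how the new `dbrx` mappings are meant to be consumed by a conversion script. It assumes the gguf-py package from this revision is importable as `gguf` and that `get_tensor_name_map` and `TensorNameMap.get_name` keep the signatures used by the existing convert scripts; the block count is a hypothetical placeholder.

    import gguf

    # Hypothetical block count for illustration; the real value comes from the model's config.
    n_blocks = 40
    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.DBRX, n_blocks)

    # A DBRX checkpoint tensor name added by this commit should now resolve
    # to its GGUF-side base name (suffix handling is left to the caller).
    name = tmap.get_name("transformer.blocks.0.norm_attn_norm.attn.Wqkv.weight",
                         try_suffixes=(".weight", ".bias"))
    print(name)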
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index a6454a10..2566b2fb 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -126,6 +126,7 @@ class MODEL_ARCH(IntEnum):
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
+ DBRX = auto()
class MODEL_TENSOR(IntEnum):
@@ -195,6 +196,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
+ MODEL_ARCH.DBRX: "dbrx",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -642,6 +644,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
],
+ MODEL_ARCH.DBRX: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_OUT_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
# TODO
}
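As a quick illustration (again, not from the commit) of what the new MODEL_TENSORS entry implies on the GGUF side, the sketch below enumerates the tensor names a DBRX conversion is expected to emit; it assumes TENSOR_NAMES in this revision maps each MODEL_TENSOR to a format string that may contain a {bid} block index.

    from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

    # Print the GGUF-side name for each tensor kind declared for DBRX,
    # formatted for block 0 as an example; names without {bid} are unchanged.
    for tensor in MODEL_TENSORS[MODEL_ARCH.DBRX]:
        print(TENSOR_NAMES[tensor].format(bid=0))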
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 4f02d298..ec6fcbb8 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -10,7 +10,7 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
- "transformer.wte", # gpt2 gpt-j mpt refact qwen
+ "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
@@ -48,7 +48,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
- "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba
+ "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
@@ -60,7 +60,7 @@ class TensorNameMap:
"transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
- "transformer.norm_f", # mpt
+ "transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
"language_model.encoder.final_layernorm", # persimmon
"model.final_layernorm", # persimmon
@@ -96,6 +96,7 @@ class TensorNameMap:
"model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
+ "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
),
# Attention norm 2
@@ -108,6 +109,7 @@ class TensorNameMap:
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen
"transformer.blocks.{bid}.attn.Wqkv", # mpt
+ "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
"transformer.h.{bid}.self_attention.query_key_value", # falcon
"h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
@@ -152,23 +154,24 @@ class TensorNameMap:
# Attention output
MODEL_TENSOR.ATTN_OUT: (
- "gpt_neox.layers.{bid}.attention.dense", # gptneox
- "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
- "transformer.blocks.{bid}.attn.out_proj", # mpt
- "transformer.h.{bid}.self_attention.dense", # falcon
- "h.{bid}.self_attention.dense", # bloom
- "model.layers.{bid}.self_attn.o_proj", # llama-hf
- "layers.{bid}.attention.wo", # llama-pth
- "encoder.layer.{bid}.attention.output.dense", # bert
- "transformer.h.{bid}.attn.out_proj", # gpt-j
- "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
- "model.layers.{bid}.self_attn.dense", # persimmon
- "h.{bid}.attn.c_proj", # gpt2
- "transformer.h.{bid}.mixer.out_proj", # phi2
- "model.layers.layers.{bid}.self_attn.o_proj", # plamo
- "model.layers.{bid}.attention.wo", # internlm2
- "encoder.layers.{bid}.attn.out_proj", # nomic-bert
- "transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
+ "gpt_neox.layers.{bid}.attention.dense", # gptneox
+ "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
+ "transformer.blocks.{bid}.attn.out_proj", # mpt
+ "transformer.h.{bid}.self_attention.dense", # falcon
+ "h.{bid}.self_attention.dense", # bloom
+ "model.layers.{bid}.self_attn.o_proj", # llama-hf
+ "layers.{bid}.attention.wo", # llama-pth
+ "encoder.layer.{bid}.attention.output.dense", # bert
+ "transformer.h.{bid}.attn.out_proj", # gpt-j
+ "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
+ "model.layers.{bid}.self_attn.dense", # persimmon
+ "h.{bid}.attn.c_proj", # gpt2
+ "transformer.h.{bid}.mixer.out_proj", # phi2
+ "model.layers.layers.{bid}.self_attn.o_proj", # plamo
+ "model.layers.{bid}.attention.wo", # internlm2
+ "encoder.layers.{bid}.attn.out_proj", # nomic-bert
+ "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
+ "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
),
# Attention output norm
@@ -176,6 +179,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
"encoder.layers.{bid}.norm1", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
+ "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
),
# Rotary embeddings
@@ -202,9 +206,10 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_GATE_INP: (
- "layers.{bid}.feed_forward.gate", # mixtral
- "model.layers.{bid}.block_sparse_moe.gate", # mixtral
- "transformer.decoder_layer.{bid}.router" # Grok
+ "layers.{bid}.feed_forward.gate", # mixtral
+ "model.layers.{bid}.block_sparse_moe.gate", # mixtral
+ "transformer.decoder_layer.{bid}.router", # Grok
+ "transformer.blocks.{bid}.ffn.router.layer", # dbrx
),
# Feed-forward up
@@ -233,6 +238,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
),
# AWQ-activation gate
@@ -251,8 +257,9 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_GATE_EXP: (
- "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
- "transformer.decoder_layer.{bid}.moe.linear" # Grok (merged)
+ "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
+ "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
),
# Feed-forward down
@@ -280,6 +287,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
),
MODEL_TENSOR.ATTN_Q_NORM: (
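Finally, a hedged sketch of how the merged-expert names added above (w1, v1, w2) would be classified by the mapping; it assumes TensorNameMap.get_type exists with this signature in this revision and uses a hypothetical block count.

    from gguf.constants import MODEL_ARCH, MODEL_TENSOR
    from gguf.tensor_mapping import get_tensor_name_map

    tmap = get_tensor_name_map(MODEL_ARCH.DBRX, 40)  # 40 is a placeholder block count

    # The DBRX expert tensors should map to the merged-expert MoE tensor kinds.
    assert tmap.get_type("transformer.blocks.0.ffn.experts.mlp.w1") == MODEL_TENSOR.FFN_GATE_EXP
    assert tmap.get_type("transformer.blocks.0.ffn.experts.mlp.v1") == MODEL_TENSOR.FFN_UP_EXP
    assert tmap.get_type("transformer.blocks.0.ffn.experts.mlp.w2") == MODEL_TENSOR.FFN_DOWN_EXP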