summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md1
-rwxr-xr-xconvert-hf-to-gguf.py96
-rw-r--r--examples/eval-callback/eval-callback.cpp26
-rw-r--r--gguf-py/gguf/constants.py15
-rw-r--r--gguf-py/gguf/tensor_mapping.py58
-rw-r--r--llama.cpp375
-rwxr-xr-xscripts/get-wikitext-2.sh3
7 files changed, 427 insertions, 147 deletions
diff --git a/README.md b/README.md
index 00a487fc..fa2a4db8 100644
--- a/README.md
+++ b/README.md
@@ -94,6 +94,7 @@ Typically finetunes of the base models below are supported as well.
- [x] LLaMA 2 🦙🦙
- [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
+- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
- [X] Falcon
- [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 63710676..e1ac09e0 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1427,6 +1427,102 @@ class GrokModel(Model):
self.gguf_writer.add_tensor(new_name, data)
+@Model.register("DbrxForCausalLM")
+class DbrxModel(Model):
+ model_arch = gguf.MODEL_ARCH.DBRX
+
+ def set_gguf_parameters(self):
+ ffn_config = self.hparams["ffn_config"]
+ attn_config = self.hparams["attn_config"]
+ self.gguf_writer.add_name(self.hparams["model_type"])
+ self.gguf_writer.add_block_count(self.hparams["n_layers"])
+
+ self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
+ self.gguf_writer.add_embedding_length(self.hparams["d_model"])
+ self.gguf_writer.add_feed_forward_length(ffn_config["ffn_hidden_size"])
+
+ self.gguf_writer.add_head_count(self.hparams["n_heads"])
+ self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"])
+
+ self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
+
+ self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
+ self.gguf_writer.add_file_type(self.ftype)
+
+ self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
+ self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])
+
+ self.gguf_writer.add_layer_norm_eps(1e-5)
+
+ self.gguf_writer.add_file_type(self.ftype)
+ print(f"gguf: file type = {self.ftype}")
+
+ def write_tensors(self):
+ block_count = self.hparams.get("n_layers")
+ tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+ for name, data_torch in self.get_tensors():
+ n_expert = self.hparams["ffn_config"]["moe_num_experts"]
+ n_ff = self.hparams["ffn_config"]["ffn_hidden_size"]
+ n_embd = self.hparams["d_model"]
+
+ # Specific behavior for experts tensors: suffix .weight, view as 3D and transpose
+ # original implementation expects (n_expert, n_ff, n_embd) for all experts weights
+ # But llama.cpp moe graph works differently
+ # AND the dimensions in ggml are typically in the reverse order of the pytorch dimensions
+ # so (n_expert, n_ff, n_embd) in pytorch is {n_embd, n_ff, n_expert} in ggml_tensor
+ exp_tensor_names = {"ffn.experts.mlp.w1": None, # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert}
+ "ffn.experts.mlp.w2": (0, 2, 1), # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff, n_embd, n_expert}
+ "ffn.experts.mlp.v1": None} # LLM_TENSOR_FFN_UP_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert}
+ experts = False
+ for exp_tensor_name in exp_tensor_names.keys():
+ if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
+ experts = True
+ data_torch = data_torch.view(n_expert, n_ff, n_embd)
+ if (permute_tensor := exp_tensor_names[exp_tensor_name]) is not None:
+ data_torch = data_torch.permute(*permute_tensor)
+ break
+
+ old_dtype = data_torch.dtype
+
+ # convert any unsupported data types to float32
+ if data_torch.dtype not in (torch.float16, torch.float32):
+ data_torch = data_torch.to(torch.float32)
+
+ data = data_torch.squeeze().numpy()
+
+ # map tensor names
+ # In MoE models the ffn tensors are typically most of the model weights,
+ # and need to be quantizable. Quantize expects tensor names to be suffixed by .weight.
+ # Every other model has the weight names ending in .weight,
+ # let's assume that is the convention which is not the case for dbrx:
+ # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
+ new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=(".weight",))
+ if new_name is None:
+ print(f"Can not map tensor {name!r}")
+ sys.exit()
+
+ n_dims = len(data.shape)
+ data_dtype = data.dtype
+
+ # Most of the codebase that takes in 1D tensors only handles F32 tensors
+ # and most of the outputs tensors are F32.
+ if data_dtype != np.float32 and n_dims == 1:
+ print(f"Can not map tensor {name!r}: all 1D tensors must be F32")
+ sys.exit()
+
+ # if f32 desired, convert any float16 to float32
+ if self.ftype == 0 and data_dtype == np.float16:
+ data = data.astype(np.float32)
+
+ # if f16 desired, convert any float32 2-dim weight tensors to float16
+ if self.ftype == 1 and data_dtype == np.float32 and n_dims > 1:
+ data = data.astype(np.float16)
+
+ print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+
+ self.gguf_writer.add_tensor(new_name, data)
+
+
@Model.register("MiniCPMForCausalLM")
class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM
diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp
index 05f7d6ab..29b5f3b3 100644
--- a/examples/eval-callback/eval-callback.cpp
+++ b/examples/eval-callback/eval-callback.cpp
@@ -28,14 +28,27 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
}
static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
+ GGML_ASSERT(n > 0);
float sum = 0;
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
printf(" [\n");
- for (int64_t i2 = 0; i2 < ne[2] && i2 < n; i2++) {
+ for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+ if (i2 == n && ne[2] > 2*n) {
+ printf(" ..., \n");
+ i2 = ne[2] - n;
+ }
printf(" [\n");
- for (int64_t i1 = 0; i1 < ne[1] && i1 < n; i1++) {
+ for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+ if (i1 == n && ne[1] > 2*n) {
+ printf(" ..., \n");
+ i1 = ne[1] - n;
+ }
printf(" [");
- for (int64_t i0 = 0; i0 < ne[0] && i0 < n; i0++) {
+ for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+ if (i0 == n && ne[0] > 2*n) {
+ printf("..., ");
+ i0 = ne[0] - n;
+ }
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v;
if (type == GGML_TYPE_F16) {
@@ -51,17 +64,14 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
} else {
GGML_ASSERT(false);
}
- printf("%8.4f", v);
+ printf("%12.4f", v);
sum += v;
- if (i0 < ne[0] - 1 && i0 < n - 1) printf(", ");
+ if (i0 < ne[0] - 1) printf(", ");
}
- if (ne[0] > n) printf(", ...");
printf("],\n");
}
- if (ne[1] > n) printf(" ...\n");
printf(" ],\n");
}
- if (ne[2] > n) printf(" ...\n");
printf(" ]\n");
printf(" sum = %f\n", sum);
}
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index a6454a10..2566b2fb 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -126,6 +126,7 @@ class MODEL_ARCH(IntEnum):
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
+ DBRX = auto()
class MODEL_TENSOR(IntEnum):
@@ -195,6 +196,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
+ MODEL_ARCH.DBRX: "dbrx",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -642,6 +644,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
],
+ MODEL_ARCH.DBRX: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_OUT_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
# TODO
}
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 4f02d298..ec6fcbb8 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -10,7 +10,7 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
- "transformer.wte", # gpt2 gpt-j mpt refact qwen
+ "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
@@ -48,7 +48,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
- "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba
+ "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
@@ -60,7 +60,7 @@ class TensorNameMap:
"transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
- "transformer.norm_f", # mpt
+ "transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
"language_model.encoder.final_layernorm", # persimmon
"model.final_layernorm", # persimmon
@@ -96,6 +96,7 @@ class TensorNameMap:
"model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
+ "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
),
# Attention norm 2
@@ -108,6 +109,7 @@ class TensorNameMap:
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen
"transformer.blocks.{bid}.attn.Wqkv", # mpt
+ "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
"transformer.h.{bid}.self_attention.query_key_value", # falcon
"h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
@@ -152,23 +154,24 @@ class TensorNameMap:
# Attention output
MODEL_TENSOR.ATTN_OUT: (
- "gpt_neox.layers.{bid}.attention.dense", # gptneox
- "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
- "transformer.blocks.{bid}.attn.out_proj", # mpt
- "transformer.h.{bid}.self_attention.dense", # falcon
- "h.{bid}.self_attention.dense", # bloom
- "model.layers.{bid}.self_attn.o_proj", # llama-hf
- "layers.{bid}.attention.wo", # llama-pth
- "encoder.layer.{bid}.attention.output.dense", # bert
- "transformer.h.{bid}.attn.out_proj", # gpt-j
- "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
- "model.layers.{bid}.self_attn.dense", # persimmon
- "h.{bid}.attn.c_proj", # gpt2
- "transformer.h.{bid}.mixer.out_proj", # phi2
- "model.layers.layers.{bid}.self_attn.o_proj", # plamo
- "model.layers.{bid}.attention.wo", # internlm2
- "encoder.layers.{bid}.attn.out_proj", # nomic-bert
- "transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
+ "gpt_neox.layers.{bid}.attention.dense", # gptneox
+ "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
+ "transformer.blocks.{bid}.attn.out_proj", # mpt
+ "transformer.h.{bid}.self_attention.dense", # falcon
+ "h.{bid}.self_attention.dense", # bloom
+ "model.layers.{bid}.self_attn.o_proj", # llama-hf
+ "layers.{bid}.attention.wo", # llama-pth
+ "encoder.layer.{bid}.attention.output.dense", # bert
+ "transformer.h.{bid}.attn.out_proj", # gpt-j
+ "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
+ "model.layers.{bid}.self_attn.dense", # persimmon
+ "h.{bid}.attn.c_proj", # gpt2
+ "transformer.h.{bid}.mixer.out_proj", # phi2
+ "model.layers.layers.{bid}.self_attn.o_proj", # plamo
+ "model.layers.{bid}.attention.wo", # internlm2
+ "encoder.layers.{bid}.attn.out_proj", # nomic-bert
+ "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
+ "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
),
# Attention output norm
@@ -176,6 +179,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
"encoder.layers.{bid}.norm1", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
+ "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
),
# Rotary embeddings
@@ -202,9 +206,10 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_GATE_INP: (
- "layers.{bid}.feed_forward.gate", # mixtral
- "model.layers.{bid}.block_sparse_moe.gate", # mixtral
- "transformer.decoder_layer.{bid}.router" # Grok
+ "layers.{bid}.feed_forward.gate", # mixtral
+ "model.layers.{bid}.block_sparse_moe.gate", # mixtral
+ "transformer.decoder_layer.{bid}.router", # Grok
+ "transformer.blocks.{bid}.ffn.router.layer", # dbrx
),
# Feed-forward up
@@ -233,6 +238,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
),
# AWQ-activation gate
@@ -251,8 +257,9 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_GATE_EXP: (
- "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
- "transformer.decoder_layer.{bid}.moe.linear" # Grok (merged)
+ "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
+ "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
),
# Feed-forward down
@@ -280,6 +287,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
),
MODEL_TENSOR.ATTN_Q_NORM: (
diff --git a/llama.cpp b/llama.cpp
index 83dd55ef..b93c1abc 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -105,7 +105,7 @@
#endif
#define LLAMA_MAX_NODES 8192
-#define LLAMA_MAX_EXPERTS 8
+#define LLAMA_MAX_EXPERTS 16
//
@@ -220,6 +220,7 @@ enum llm_arch {
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
+ LLM_ARCH_DBRX,
LLM_ARCH_UNKNOWN,
};
@@ -252,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
+ { LLM_ARCH_DBRX, "dbrx" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -935,6 +937,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
+ LLM_ARCH_DBRX,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ },
+ },
+ {
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -1707,6 +1725,7 @@ enum e_model {
MODEL_XL,
MODEL_8x7B,
MODEL_8x22B,
+ MODEL_16x12B,
};
static const size_t kiB = 1024;
@@ -3562,6 +3581,7 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_XL: return "1.5B";
case MODEL_8x7B: return "8x7B";
case MODEL_8x22B: return "8x22B";
+ case MODEL_16x12B: return "16x12B";
default: return "?B";
}
}
@@ -3983,6 +4003,16 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_DBRX:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+ ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
+
+ switch (hparams.n_layer) {
+ case 40: model.type = e_model::MODEL_16x12B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
@@ -4671,6 +4701,39 @@ static bool llm_load_tensors(
layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
}
} break;
+ case LLM_ARCH_DBRX:
+ {
+ if (n_expert == 0) {
+ throw std::runtime_error("DBRX model cannot have zero experts");
+ }
+
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+ layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+ layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
+
+ layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
+ layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert});
+ layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
+ layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});
+ }
+ } break;
case LLM_ARCH_BAICHUAN:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -6433,62 +6496,7 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
- ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
- cb(logits, "ffn_moe_logits", il);
-
- ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
- cb(probs, "ffn_moe_probs", il);
-
- // select experts
- ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
- cb(selected_experts->src[0], "ffn_moe_argsort", il);
-
- ggml_tensor * weights = ggml_get_rows(ctx0,
- ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
- cb(weights, "ffn_moe_weights", il);
-
- weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
-
- ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
- cb(weights_sum, "ffn_moe_weights_sum", il);
-
- weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
- cb(weights, "ffn_moe_weights_norm", il);
-
- // compute expert outputs
- ggml_tensor * moe_out = nullptr;
-
- for (int i = 0; i < n_expert_used; ++i) {
- ggml_tensor * cur_expert;
-
- ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
- cb(cur_up, "ffn_moe_up", il);
-
- ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
- cb(cur_gate, "ffn_moe_gate", il);
-
- cur_gate = ggml_silu(ctx0, cur_gate);
- cb(cur_gate, "ffn_moe_silu", il);
-
- cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
- cb(cur_expert, "ffn_moe_gate_par", il);
-
- cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
- cb(cur_expert, "ffn_moe_down", il);
-
- cur_expert = ggml_mul(ctx0, cur_expert,
- ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
- cb(cur_expert, "ffn_moe_weighted", il);
-
- if (i == 0) {
- moe_out = cur_expert;
- } else {
- moe_out = ggml_add(ctx0, moe_out, cur_expert);
- cb(moe_out, "ffn_moe_out", il);
- }
- }
-
- cur = moe_out;
+ cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
@@ -6520,6 +6528,78 @@ struct llm_build_context {
return gf;
}
+ // REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505
+ ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, int il) {
+ ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+ cb(logits, "ffn_moe_logits", il);
+
+ ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+ cb(probs, "ffn_moe_probs", il);
+
+ // select experts
+ ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
+ cb(selected_experts->src[0], "ffn_moe_argsort", il);
+
+ ggml_tensor * weights = ggml_get_rows(ctx0,
+ ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
+ cb(weights, "ffn_moe_weights", il);
+
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+
+ ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+ cb(weights_sum, "ffn_moe_weights_sum", il);
+
+ weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+ cb(weights, "ffn_moe_weights_norm", il);
+
+ // compute expert outputs
+ ggml_tensor * moe_out = nullptr;
+
+ for (int i = 0; i < n_expert_used; ++i) {
+ ggml_tensor * cur_expert;
+
+ ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
+ cb(cur_up, "ffn_moe_up", il);
+
+ ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
+ cb(gate, "ffn_moe_gate", il);
+
+ switch (type_op) {
+ case LLM_FFN_SILU:
+ {
+ gate = ggml_silu(ctx0, gate);
+ cb(gate, "ffn_moe_silu", il);
+ } break;
+ case LLM_FFN_GELU:
+ {
+ gate = ggml_gelu(ctx0, gate);
+ cb(gate, "ffn_moe_gelu", il);
+ } break;
+ default:
+ GGML_ASSERT(false);
+ }
+
+ cur_expert = ggml_mul(ctx0, cur_up, gate);
+ cb(cur_expert, "ffn_moe_gate_par", il);
+
+ cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+ cb(cur_expert, "ffn_moe_down", il);
+
+ cur_expert = ggml_mul(ctx0, cur_expert,
+ ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+ cb(cur_expert, "ffn_moe_weighted", il);
+
+ if (i == 0) {
+ moe_out = cur_expert;
+ } else {
+ moe_out = ggml_add(ctx0, moe_out, cur_expert);
+ cb(moe_out, "ffn_moe_out", il);
+ }
+ }
+
+ return moe_out;
+ }
+
struct ggml_cgraph * build_baichuan() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@@ -6967,74 +7047,143 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
- ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
- cb(logits, "ffn_moe_logits", il);
+ cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, il);
+
+ // Grok
+ // if layer_out_norm is present then apply it before adding the input
+ // Idea: maybe ffn_out_norm is a better name
+ if (model.layers[il].layer_out_norm) {
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].layer_out_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "layer_out_norm", il);
+ }
+
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "ffn_out", il);
+
+ ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+ if (layer_dir != nullptr) {
+ cur = ggml_add(ctx0, cur, layer_dir);
+ }
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
- ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
- cb(probs, "ffn_moe_probs", il);
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
- // select experts
- ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
- cb(selected_experts->src[0], "ffn_moe_argsort", il);
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
- ggml_tensor * weights = ggml_get_rows(ctx0,
- ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
- cb(weights, "ffn_moe_weights", il);
+ // Grok
+ // multiply logits by output_multiplier_scale of 0.5773502691896257
- weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+ cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
- ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
- cb(weights_sum, "ffn_moe_weights_sum", il);
+ cb(cur, "result_output", -1);
- weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
- cb(weights, "ffn_moe_weights_norm", il);
+ ggml_build_forward_expand(gf, cur);
- // compute expert outputs
- ggml_tensor * moe_out = nullptr;
+ return gf;
+ }
- for (int i = 0; i < n_expert_used; ++i) {
- ggml_tensor * cur_expert;
+ struct ggml_cgraph * build_dbrx() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
- ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
- cb(cur_up, "ffn_moe_up", il);
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
- ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
- cb(cur_gate, "ffn_moe_gate", il);
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
- //GeLU
- cur_gate = ggml_gelu(ctx0, cur_gate);
- cb(cur_gate, "ffn_moe_gelu", il);
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
- cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
- cb(cur_expert, "ffn_moe_gate_par", il);
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
- cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
- cb(cur_expert, "ffn_moe_down", il);
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
- cur_expert = ggml_mul(ctx0, cur_expert,
- ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
- cb(cur_expert, "ffn_moe_weighted", il);
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
- if (i == 0) {
- moe_out = cur_expert;
- } else {
- moe_out = ggml_add(ctx0, moe_out, cur_expert);
- cb(moe_out, "ffn_moe_out", il);
- }
- }
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
- cur = moe_out;
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM, cb, il);
+ cb(cur, "attn_norm", il);
- // Grok
- // if layer_out_norm is present then apply it before adding the input
- // Idea: maybe ffn_out_norm is a better name
- if (model.layers[il].layer_out_norm) {
- cur = llm_build_norm(ctx0, cur, hparams,
- model.layers[il].layer_out_norm, NULL,
- LLM_NORM_RMS, cb, il);
- cb(cur, "layer_out_norm", il);
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = nullptr;
+ struct ggml_tensor * Kcur = nullptr;
+ struct ggml_tensor * Vcur = nullptr;
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+ cb(cur, "wqkv", il);
+
+ cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+ cb(cur, "wqkv_clamped", il);
+
+ Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ model.layers[il].wo, NULL,
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ // MoE branch
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].attn_out_norm, NULL,
+ LLM_NORM, cb, il);
+ cb(cur, "attn_out_norm", il);
+
+ cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
@@ -7052,18 +7201,13 @@ struct llm_build_context {
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
- model.output_norm, NULL,
- LLM_NORM_RMS, cb, -1);
+ model.output_norm, NULL,
+ LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
- // Grok
- // multiply logits by output_multiplier_scale of 0.5773502691896257
-
- cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
-
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
@@ -9785,6 +9929,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_command_r();
} break;
+ case LLM_ARCH_DBRX:
+ {
+ result = llm.build_dbrx();
+ } break;
default:
GGML_ASSERT(false);
}
@@ -14638,6 +14786,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
// the pairs of head values are offset by n_rot/2
case LLM_ARCH_FALCON:
case LLM_ARCH_GROK:
+ case LLM_ARCH_DBRX:
case LLM_ARCH_PERSIMMON:
case LLM_ARCH_BERT:
case LLM_ARCH_NOMIC_BERT:
diff --git a/scripts/get-wikitext-2.sh b/scripts/get-wikitext-2.sh
index 7ca760fa..b01476a4 100755
--- a/scripts/get-wikitext-2.sh
+++ b/scripts/get-wikitext-2.sh
@@ -1,10 +1,11 @@
#!/bin/bash
wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
+unzip wikitext-2-raw-v1.zip
echo "Usage:"
echo ""
-echo " ./perplexity -m model.gguf -f wiki.test.raw [other params]"
+echo " ./perplexity -m model.gguf -f wikitext-2-raw/wiki.test.raw [other params]"
echo ""
exit 0