author     momonga <115213907+mmnga@users.noreply.github.com>    2023-09-03 14:36:28 +0900
committer  GitHub <noreply@github.com>                           2023-09-03 08:36:28 +0300
commit     c42f0ec6b344e14bd81c8612ab1445b3ff77358b (patch)
tree       b1f1f34f8fc8b8a52771da9b387fcda4d8fafca7 /llama.cpp
parent     2753415afdaf22a18c49608bd9d93cfffc05d435 (diff)
examples : fix gpt-neox (#2943)
Co-authored-by: mmnga <mmnga1mmnga@gmail.com>
Diffstat (limited to 'llama.cpp')
-rw-r--r--  llama.cpp  46
1 file changed, 44 insertions, 2 deletions
diff --git a/llama.cpp b/llama.cpp
index 3114d331..2b0cf30f 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -325,6 +325,44 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_GPT2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+        },
+    },
+    {
+        LLM_ARCH_GPTJ,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+        },
+    },
+    {
+        LLM_ARCH_GPTNEOX,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_MPT,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+        },
+    },
+    {
+        LLM_ARCH_UNKNOWN,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+        },
+    },
 };

 static llm_arch llm_arch_from_string(const std::string & name) {
@@ -1605,9 +1643,13 @@ static void llm_load_hparams(
         GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));

-        if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
-            throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
+        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
+            if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
+                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
+            }
         }
+        // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
+        // gpt-j n_rot = rotary_dim
     }

     // arch-specific KVs
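For reference, the n_rot relationships noted in the new comments can be sketched as below. This is an illustration only, not part of the patch: the helper names are hypothetical, and rotary_pct / rotary_dim refer to the upstream GPT-NeoX and GPT-J configuration fields mentioned in the comments.

#include <cstdint> // for uint32_t in this standalone sketch

// Sketch only (hypothetical helpers, not from the patch): how n_rot is derived
// per the comments added above.
static uint32_t n_rot_gpt_neox(float rotary_pct, uint32_t n_embd, uint32_t n_head) {
    // gpt-neox: n_rot = rotary_pct * (n_embd / n_head); rotary_pct comes from the model config
    return (uint32_t)(rotary_pct * (n_embd / n_head));
}

static uint32_t n_rot_gpt_j(uint32_t rotary_dim) {
    // gpt-j: the config stores the rotary dimension directly
    return rotary_dim;
}

With rotary factors like these, n_rot generally does not equal n_embd / n_head, which is why the strict check is now limited to LLM_ARCH_LLAMA and LLM_ARCH_FALCON.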