summary | refs | log | tree | commit | diff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r-- llama.cpp 32
1 files changed, 32 insertions, 0 deletions
diff --git a/llama.cpp b/llama.cpp
index 92c4536c..3f6b7fe7 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -255,6 +255,8 @@ enum llm_kv {
LLM_KV_TOKENIZER_UNK_ID,
LLM_KV_TOKENIZER_SEP_ID,
LLM_KV_TOKENIZER_PAD_ID,
+ LLM_KV_TOKENIZER_ADD_BOS,
+ LLM_KV_TOKENIZER_ADD_EOS,
LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV,
};
@@ -303,6 +305,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
{ LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
+ { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
+ { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
};
@@ -1276,6 +1280,9 @@ struct llama_vocab {
id special_sep_id = -1;
id special_pad_id = -1;
+ int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
+ int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
+
id linefeed_id = 13;
id special_prefix_id = 32007;
id special_middle_id = 32009;
@@ -2388,6 +2395,23 @@ static void llm_load_vocab(
__func__, key.c_str(), id, old_id);
id = old_id;
}
+
+ }
+
+ // Handle add_bos_token and add_eos_token
+ std::string key = kv(LLM_KV_TOKENIZER_ADD_BOS);
+ int kid = gguf_find_key(ctx, key.c_str());
+ enum gguf_type ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
+ vocab.special_add_bos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
+ if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
+ LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
+ }
+ key = kv(LLM_KV_TOKENIZER_ADD_EOS);
+ kid = gguf_find_key(ctx, key.c_str());
+ ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
+ vocab.special_add_eos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
+ if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
+ LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
}
}
@@ -9288,6 +9312,14 @@ llama_token llama_token_nl(const struct llama_model * model) {
return model->vocab.linefeed_id;
}
+int llama_add_bos_token(const struct llama_model * model) { // Reports the vocab's add-BOS setting; model is dereferenced unchecked.
+ return model->vocab.special_add_bos; // -1 unknown, 1 add, 0 don't add; loaded from the bool KV "tokenizer.ggml.add_bos_token", -1 when absent or of another type.
+}
+
+int llama_add_eos_token(const struct llama_model * model) { // Reports the vocab's add-EOS setting; model is dereferenced unchecked.
+ return model->vocab.special_add_eos; // -1 unknown, 1 add, 0 don't add; loaded from the bool KV "tokenizer.ggml.add_eos_token", -1 when absent or of another type.
+}
+
llama_token llama_token_prefix(const struct llama_model * model) {
return model->vocab.special_prefix_id;
}