-rwxr-xr-x  convert-hf-to-gguf.py        15
-rw-r--r--  gguf-py/gguf/constants.py     9
-rw-r--r--  gguf-py/gguf/gguf_writer.py  12
-rw-r--r--  llama.cpp                    58
4 files changed, 83 insertions, 11 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index b51d6830..6d28ab5e 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1221,6 +1221,14 @@ class LlamaModel(Model):
except FileNotFoundError:
self._set_vocab_llama_hf()
+ special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
+ special_token_types=['prefix', 'suffix', 'middle', 'eot'])
+ special_vocab._set_special_token("prefix", 32007)
+ special_vocab._set_special_token("suffix", 32008)
+ special_vocab._set_special_token("middle", 32009)
+ special_vocab._set_special_token("eot", 32010)
+ special_vocab.add_to_gguf(self.gguf_writer)
+
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
@@ -2240,6 +2248,13 @@ class GemmaModel(Model):
def set_vocab(self):
self._set_vocab_sentencepiece()
+ special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
+ special_token_types=['prefix', 'suffix', 'middle', 'eot'])
+ special_vocab._set_special_token("prefix", 67)
+ special_vocab._set_special_token("suffix", 69)
+ special_vocab._set_special_token("middle", 68)
+ special_vocab._set_special_token("eot", 70)
+ special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
hparams = self.hparams
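
For context, a minimal sketch of what this conversion path writes, assuming gguf-py is installed; the model directory and output filename below are placeholders, and _set_special_token is the same internal helper the converter uses above:

import gguf

writer = gguf.GGUFWriter("out.gguf", arch="llama")  # placeholder output path
special_vocab = gguf.SpecialVocab("model-dir", load_merges=False,
                                  special_token_types=['prefix', 'suffix', 'middle', 'eot'])
special_vocab._set_special_token("prefix", 32007)  # CodeLlama <PRE>
special_vocab._set_special_token("suffix", 32008)  # CodeLlama <SUF>
special_vocab._set_special_token("middle", 32009)  # CodeLlama <MID>
special_vocab._set_special_token("eot",    32010)  # CodeLlama <EOT>
special_vocab.add_to_gguf(writer)  # emits the tokenizer.ggml.*_token_id key-values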
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 2566b2fb..1358206a 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -90,6 +90,11 @@ class Keys:
HF_JSON = "tokenizer.huggingface.json"
RWKV = "tokenizer.rwkv.world"
CHAT_TEMPLATE = "tokenizer.chat_template"
+ # FIM/infill special token constants
+ PREFIX_ID = "tokenizer.ggml.prefix_token_id"
+ SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
+ MIDDLE_ID = "tokenizer.ggml.middle_token_id"
+ EOT_ID = "tokenizer.ggml.eot_token_id"
#
@@ -885,3 +890,7 @@ KEY_TOKENIZER_CLS_ID = Keys.Tokenizer.CLS_ID
KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID
KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON
KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV
+KEY_TOKENIZER_PREFIX_ID = Keys.Tokenizer.PREFIX_ID
+KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
+KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
+KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID
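
For reference, a quick sketch of reading the new key-values back out of a converted model, assuming gguf-py's GGUFReader is available; "model.gguf" is a placeholder path:

from gguf import GGUFReader, Keys

reader = GGUFReader("model.gguf")
for key in (Keys.Tokenizer.PREFIX_ID, Keys.Tokenizer.SUFFIX_ID,
            Keys.Tokenizer.MIDDLE_ID, Keys.Tokenizer.EOT_ID):
    field = reader.fields.get(key)
    if field is not None:
        # each of these fields carries a single uint32 token id
        print(key, field.parts[field.data[0]][0])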
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index f4c44076..ff9326d5 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -469,6 +469,18 @@ class GGUFWriter:
def add_chat_template(self, value: str) -> None:
self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
+ def add_prefix_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.PREFIX_ID, id)
+
+ def add_suffix_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.SUFFIX_ID, id)
+
+ def add_middle_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.MIDDLE_ID, id)
+
+ def add_eot_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.EOT_ID, id)
+
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
pack_prefix = ''
if not skip_pack_prefix:
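
The four new setters mirror the existing special-token setters above them. A usage sketch, with a placeholder filename and the CodeLlama ids:

import gguf

writer = gguf.GGUFWriter("fim.gguf", arch="llama")  # "fim.gguf" is a placeholder
writer.add_prefix_token_id(32007)
writer.add_suffix_token_id(32008)
writer.add_middle_token_id(32009)
writer.add_eot_token_id(32010)
# ... vocabulary and tensor data would be added here as usual ...
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()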
diff --git a/llama.cpp b/llama.cpp
index a5ef2fd8..38e59362 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -327,6 +327,10 @@ enum llm_kv {
LLM_KV_TOKENIZER_ADD_PREFIX,
LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV,
+ LLM_KV_TOKENIZER_PREFIX_ID,
+ LLM_KV_TOKENIZER_SUFFIX_ID,
+ LLM_KV_TOKENIZER_MIDDLE_ID,
+ LLM_KV_TOKENIZER_EOT_ID,
};
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
@@ -399,6 +403,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
+ { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
+ { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
+ { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" },
+ { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
};
struct LLM_KV {
@@ -2055,10 +2063,10 @@ struct llama_vocab {
int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
id linefeed_id = 13;
- id special_prefix_id = 32007;
- id special_middle_id = 32009;
- id special_suffix_id = 32008;
- id special_eot_id = 32010;
+ id special_prefix_id = -1;
+ id special_suffix_id = -1;
+ id special_middle_id = -1;
+ id special_eot_id = -1;
bool add_space_prefix = true;
@@ -4072,6 +4080,30 @@ static void llm_load_vocab(
vocab.special_cls_id = -1;
vocab.special_mask_id = -1;
+ // For Fill-In-the-Middle (FIM)/infill models which were converted
+ // prior to support of FIM special tokens in GGUF, the following
+ // will allow those models to continue to work. The general names
+ // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and
+ // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once
+ // new versions of these models have been published.
+ std::string gen_name;
+ ml.get_key(LLM_KV_GENERAL_NAME, gen_name);
+ std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(),
+ [](unsigned char c){ return std::tolower(c); });
+ if (gen_name.find("code") != std::string::npos) {
+ if (model.arch == LLM_ARCH_LLAMA) {
+ vocab.special_prefix_id = 32007;
+ vocab.special_suffix_id = 32008;
+ vocab.special_middle_id = 32009;
+ vocab.special_eot_id = 32010;
+ } else if (model.arch == LLM_ARCH_GEMMA) {
+ vocab.special_prefix_id = 67;
+ vocab.special_suffix_id = 69;
+ vocab.special_middle_id = 68;
+ vocab.special_eot_id = 70;
+ }
+ }
+
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) {
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
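
The fallback heuristic above, restated in Python for illustration (the dict and function names are hypothetical; the ids are the ones hard-coded in the hunk):

# Models converted before the FIM keys existed get hard-coded ids, selected by
# a lowercased general.name containing "code" plus the model architecture.
FIM_FALLBACK_IDS = {
    "llama": {"prefix": 32007, "suffix": 32008, "middle": 32009, "eot": 32010},  # CodeLlama
    "gemma": {"prefix": 67, "suffix": 69, "middle": 68, "eot": 70},              # CodeGemma
}

def fim_fallback(general_name: str, arch: str) -> dict | None:
    if "code" in general_name.lower():
        return FIM_FALLBACK_IDS.get(arch)  # None for other architectures
    return None  # ids stay -1 (unset)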
@@ -4185,13 +4217,17 @@ static void llm_load_vocab(
// special tokens
{
const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
- { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
- { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
- { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
- { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
- { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
- { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
- { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
+ { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
+ { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
+ { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
+ { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
+ { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
+ { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
+ { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
+ { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
+ { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
+ { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
+ { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
};
for (const auto & it : special_token_types) {
const std::string & key = kv(std::get<0>(it));
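
The loop that follows this table reads each key generically, so the four new FIM entries need no loader changes. A Python sketch of the same table-driven pattern (function and dict names are hypothetical):

def load_special_token_ids(kv: dict, ids: dict) -> dict:
    # (gguf key -> vocab field) pairs, mirroring special_token_types above;
    # keys absent from the file leave the -1 default untouched
    table = {
        "tokenizer.ggml.bos_token_id":    "bos",
        "tokenizer.ggml.eos_token_id":    "eos",
        "tokenizer.ggml.prefix_token_id": "prefix",
        "tokenizer.ggml.suffix_token_id": "suffix",
        "tokenizer.ggml.middle_token_id": "middle",
        "tokenizer.ggml.eot_token_id":    "eot",
    }
    out = dict(ids)
    for key, field in table.items():
        if key in kv:
            out[field] = int(kv[key])
    return out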