summaryrefslogtreecommitdiff
path: root/convert.py
diff options
context:
space:
mode:
Diffstat (limited to 'convert.py')
-rwxr-xr-xconvert.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/convert.py b/convert.py
index 24da25ef..0680f71e 100755
--- a/convert.py
+++ b/convert.py
@@ -369,7 +369,7 @@ class SentencePieceVocab:
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
actual_ids = sorted(added_tokens.values())
if expected_ids != actual_ids:
- raise Exception(f"Expected added token IDs to be sequential and start at {len(added_tokens)}; got {actual_ids}")
+ raise Exception(f"Expected added token IDs to be sequential and start at {vocab_size}; got {actual_ids}")
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
self.added_tokens_list = [text for (text, idx) in items]
@@ -1163,10 +1163,13 @@ def main(args_in: list[str] | None = None) -> None:
vocab: Vocab
if args.vocab_only:
- assert args.outfile, "need --outfile if using --vocab-only"
+ if not args.outfile:
+ raise ValueError("need --outfile if using --vocab-only")
# FIXME: Try to respect vocab_dir somehow?
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
- special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
+ special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
+ load_merges = args.vocabtype == 'bpe',
+ n_vocab = vocab.vocab_size)
outfile = args.outfile
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
print(f"Wrote {outfile}")
@@ -1178,7 +1181,9 @@ def main(args_in: list[str] | None = None) -> None:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
vocab = load_vocab(vocab_dir, args.vocabtype)
# FIXME: Try to respect vocab_dir somehow?
- special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
+ special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
+ load_merges = args.vocabtype == 'bpe',
+ n_vocab = vocab.vocab_size)
model = model_plus.model
model = convert_model_names(model, params)