author     Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>  2023-08-30 02:25:50 -0600
committer  GitHub <noreply@github.com>  2023-08-30 11:25:50 +0300
commit     dc07dc492ef9640bbb82904d7c7679f7bdcf6d76 (patch)
tree       f9d80bc6ee29067e8e72521d75dfa2b92d85540e /convert.py
parent     ad9ddcff6ef322db5cf13785bd7c856b610d242e (diff)
convert : various script cleanups/fixes + merges and special token handling (#2842)
* convert: Fix permute calls and method/func definitions
* Cleanups for gguf-py
* Minor types cleanups.
* Initial implementation of handling merges and special tokens
* convert: Handle special tokens and merges in vocab only mode
  convert: Vocab only mode no longer requires loading model tensors
* gguf: Refactor tensor name mapping
* convert: Fix type hint for special_token_types in SpecialVocab
* Use common special vocab handling in various conversion scripts
* First pass at implementing suggested changes
* Second pass
* gguf: SpecialVocab: Fix issue with special token content not in a dict
  gguf: SpecialVocab: Allow skipping handling of merges
* convert-falcon-hf-to-gguf: Support --vocab-only option, bail out if no tokenizer.json
* convert-gptneox-hf-to-gguf and convert: Only handle merges for BPE tokenizer
* gguf: SpecialVocab: Actually set load_merges in object
* Uniform args parsing and vocab only mode for convert examples
* convert.py: Set gpt2 as tokenizer model when using BPE
* Squish last type warning in gguf.py - yay!
Diffstat (limited to 'convert.py')
-rwxr-xr-x  convert.py   148
1 file changed, 85 insertions(+), 63 deletions(-)
diff --git a/convert.py b/convert.py
index 3f0a1c93..448b6f0f 100755
--- a/convert.py
+++ b/convert.py
@@ -25,7 +25,7 @@ import numpy as np
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from pathlib import Path
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, Type, TypeVar, Union)
from sentencepiece import SentencePieceProcessor # type: ignore
if TYPE_CHECKING:
@@ -299,8 +299,10 @@ class Params:
params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
elif orig_config_path.exists():
params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
- else:
+ elif model_plus.format != 'none':
params = Params.guessed(model_plus.model)
+ else:
+ raise ValueError('Cannot guess params when model format is none')
params.path_model = model_plus.paths[0].parent
@@ -353,7 +355,7 @@ class BpeVocab:
yield from self.added_tokens()
def __repr__(self) -> str:
- return f"BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
+ return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class SentencePieceVocab:
@@ -416,7 +418,6 @@ class SentencePieceVocab:
Vocab = Union[BpeVocab, SentencePieceVocab]
-
#
# data loading
# TODO: reuse (probably move to gguf.py?)
@@ -439,14 +440,14 @@ class Tensor(metaclass=ABCMeta):
@abstractmethod
def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ...
@abstractmethod
- def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
+ def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor': ...
@abstractmethod
def part(self, n_part: int) -> 'UnquantizedTensor': ...
@abstractmethod
def to_ggml(self) -> 'GGMLCompatibleTensor': ...
-def bf16_to_fp32(bf16_arr: np.ndarray) -> np.ndarray:
+def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
fp32_arr = bf16_arr.astype(np.uint32) << 16
return fp32_arr.view(np.float32)
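
Aside on the function above (illustrative, not part of the diff): bfloat16 shares float32's sign and exponent layout, so the conversion only needs to place the raw 16 bits in the upper half of a 32-bit word and reinterpret the result. A minimal, self-contained sketch of the same widening:

    # Sketch of the bf16 -> fp32 widening performed by bf16_to_fp32 (illustrative only).
    import numpy as np

    bf16_raw = np.array([0x3F80, 0xC000, 0x4049], dtype=np.uint16)   # 1.0, -2.0, ~pi in bfloat16
    fp32 = (bf16_raw.astype(np.uint32) << 16).view(np.float32)       # shift into the high half, reinterpret
    print(fp32)                                                      # [ 1.  -2.  3.140625]
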
@@ -467,9 +468,9 @@ class UnquantizedTensor(Tensor):
def to_ggml(self) -> 'UnquantizedTensor':
return self
- def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
+ def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
r = self.ndarray.shape[0] // 3
- return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head))
+ return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv))
def part(self, n_part: int) -> 'UnquantizedTensor':
r = self.ndarray.shape[0] // 3
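
Aside (illustrative, not from the diff): permute_part() and part() above slice a fused attention weight into three equal row blocks, one each for Q, K and V; this is what later lets the per-layer "W_pack" tensors be unpacked. A minimal sketch of the slicing with a hypothetical shape:

    # Sketch: splitting a fused QKV matrix into thirds, mirroring part()/permute_part().
    import numpy as np

    rows, cols = 12, 4                                   # hypothetical fused weight shape
    w_pack = np.arange(rows * cols).reshape(rows, cols)

    def take_part(arr: np.ndarray, n_part: int) -> np.ndarray:
        r = arr.shape[0] // 3                            # each of Q, K, V owns one third of the rows
        return arr[r * n_part : r * n_part + r, ...]

    q, k, v = (take_part(w_pack, i) for i in range(3))
    assert q.shape == k.shape == v.shape == (rows // 3, cols)
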
@@ -531,7 +532,7 @@ LazyModel = Dict[str, LazyTensor]
class ModelPlus:
model: LazyModel
paths: List[Path] # Where this was read from.
- format: Literal['ggml', 'torch', 'safetensors']
+ format: Literal['ggml', 'torch', 'safetensors', 'none']
vocab: Optional[Vocab] # For GGML models (which have vocab built in), the vocab.
@@ -597,12 +598,12 @@ def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTe
return lazy_tensor.load().permute(n_head, n_head_kv)
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
-def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
+def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor:
def load() -> Tensor:
- return lazy_tensor.load().permute_part(n_part, n_head)
+ return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv)
s = lazy_tensor.shape.copy()
s[0] = s[0] // 3
- return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
+ return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
def load() -> Tensor:
@@ -657,7 +658,7 @@ class LazyUnpickler(pickle.Unpickler):
description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
return LazyStorage(load=load, kind=pid[1], description=description)
- # @staticmethod
+ @staticmethod
def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,
# pyright: ignore[reportSelfClsParameterName]
requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
@@ -669,13 +670,15 @@ class LazyUnpickler(pickle.Unpickler):
description = f'pickled storage_offset={storage_offset} in {storage.description}'
return LazyTensor(load, list(size), storage.kind.data_type, description)
- # @staticmethod
+ @staticmethod
def rebuild_from_type_v2(func, new_type, args, state):
return func(*args)
- CLASSES: Dict[Any, Any] = {
- ('torch._tensor', '_rebuild_from_type_v2'): rebuild_from_type_v2,
- ('torch._utils', '_rebuild_tensor_v2'): lazy_rebuild_tensor_v2,
+ CLASSES: Dict[Tuple[str, str], Any] = {
+ # getattr used here as a workaround for mypy not being smart enough to determine
+ # the staticmethods have a __func__ attribute.
+ ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
+ ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
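
Aside on the getattr(..., '__func__') change above (illustrative, not from the diff): inside a class body the decorated names still refer to staticmethod descriptor objects, which are not directly callable before Python 3.10 and which mypy treats as non-callable, so the wrapped function is fetched through __func__ when building the dispatch table. A small standalone sketch of the same pattern:

    # Sketch: referencing a staticmethod's underlying function from inside the class body.
    class Demo:
        @staticmethod
        def handler(x: int) -> int:
            return x * 2

        # `handler` here is a staticmethod object, not a plain function;
        # its __func__ attribute holds the wrapped callable.
        TABLE = {'double': getattr(handler, '__func__')}

    print(Demo.TABLE['double'](21))  # 42
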
@@ -751,7 +754,7 @@ def lazy_load_file(path: Path) -> ModelPlus:
In = TypeVar('In')
Out = TypeVar('Out')
-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, use_processpool_executor: bool = False) -> Iterable[Out]:
'''Parallel map, but with backpressure. If the caller doesn't call `next`
fast enough, this will stop calling `func` at some point rather than
letting results pile up in memory. Specifically, there is a max of one
@@ -760,7 +763,12 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
yield from map(func, iterable)
# Not reached.
iterable = iter(iterable)
- with factory(max_workers = max_workers) as executor:
+ executor_class: Union[Type[ThreadPoolExecutor], Type[ProcessPoolExecutor]]
+ if use_processpool_executor:
+ executor_class = ProcessPoolExecutor
+ else:
+ executor_class = ThreadPoolExecutor
+ with executor_class(max_workers = max_workers) as executor:
futures: List[concurrent.futures.Future[Out]] = []
done = False
for _ in range(concurrency):
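
Aside (illustrative, not convert.py code): the signature change above swaps the free-form `factory` callable for a `use_processpool_executor` flag, which makes the two supported executors explicit and simpler to type-annotate. A self-contained sketch of the same selection pattern, with a toy work function standing in for maybe_do_quantize:

    # Sketch: choosing between a thread pool and a process pool behind one boolean flag.
    from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
    from typing import Iterable, Iterator, Type, Union

    def run(items: Iterable[int], use_processpool_executor: bool = False) -> Iterator[int]:
        items = list(items)                    # materialize so the iterable can be reused below
        executor_class: Union[Type[ThreadPoolExecutor], Type[ProcessPoolExecutor]]
        executor_class = ProcessPoolExecutor if use_processpool_executor else ThreadPoolExecutor
        with executor_class(max_workers = 4) as executor:
            # Process pools require picklable callables and arguments; threads do not.
            yield from executor.map(pow, items, items)

    if __name__ == '__main__':
        print(list(run(range(5))))                                    # [1, 1, 4, 27, 256]
        print(list(run(range(5), use_processpool_executor = True)))   # same result, separate processes
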
@@ -838,11 +846,19 @@ class OutputFile:
scores.append(score)
toktypes.append(toktype)
- self.gguf.add_tokenizer_model("llama")
+ if isinstance(vocab, SentencePieceVocab):
+ self.gguf.add_tokenizer_model("llama")
+ elif isinstance(vocab, BpeVocab):
+ self.gguf.add_tokenizer_model("gpt2")
+ else:
+ raise ValueError(f'Unknown vocab type: Not BpeVocab or SentencePieceVocab')
self.gguf.add_token_list(tokens)
self.gguf.add_token_scores(scores)
self.gguf.add_token_types(toktypes)
+ def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None:
+ svocab.add_to_gguf(self.gguf)
+
def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
n_elements = int(np.prod(tensor.shape))
raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
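
Aside (hedged, not part of the diff): the only SpecialVocab API this commit relies on in convert.py is constructing it from the model directory and forwarding it to the writer, so a usage sketch stays small. The directory and vocab type below are hypothetical, and any gguf-py behaviour beyond these two calls is an assumption.

    # Sketch: building the special-vocab metadata the same way main() does further down.
    from pathlib import Path
    import gguf

    model_dir = Path('models/my-model')    # hypothetical tokenizer/model directory
    vocabtype = 'bpe'                      # convert.py only loads merges for BPE vocabs

    special_vocab = gguf.SpecialVocab(model_dir, load_merges = (vocabtype == 'bpe'))
    # OutputFile.add_meta_special_vocab() then simply forwards it:
    #     svocab.add_to_gguf(self.gguf)
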
@@ -861,7 +877,7 @@ class OutputFile:
self.gguf.close()
@staticmethod
- def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
+ def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab) -> None:
check_vocab_size(params, vocab)
of = OutputFile(fname_out)
@@ -869,6 +885,8 @@ class OutputFile:
# meta data
of.add_meta_arch(params)
of.add_meta_vocab(vocab)
+ of.add_meta_special_vocab(svocab)
+
of.write_meta()
of.close()
@@ -887,7 +905,7 @@ class OutputFile:
return dt.quantize(arr)
@staticmethod
- def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
+ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
check_vocab_size(params, vocab)
of = OutputFile(fname_out)
@@ -895,6 +913,7 @@ class OutputFile:
# meta data
of.add_meta_arch(params)
of.add_meta_vocab(vocab)
+ of.add_meta_special_vocab(svocab)
# tensor info
for name, lazy_tensor in model.items():
@@ -906,7 +925,7 @@ class OutputFile:
# tensor data
ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
if ftype == GGMLFileType.MostlyQ8_0:
- ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
+ ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, use_processpool_executor = True)
else:
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
@@ -939,7 +958,8 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM
for (name, tensor) in model.items()}
def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
- tmap = gguf.get_tensor_name_map(ARCH, params.n_layer)
+ tmap = gguf.TensorNameMap(ARCH, params.n_layer)
+ should_skip: Set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
tmp = model
@@ -952,8 +972,8 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
#tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
print(f"Unpacking and permuting layer {i}")
- tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
- tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
+ tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head)
+ tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv)
tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
else:
@@ -961,23 +981,16 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
out: LazyModel = {}
for name, lazy_tensor in model.items():
- name_new = name
-
- if name in tmap:
- name_new = tmap[name]
- elif name.endswith(".weight") and name[:-7] in tmap:
- name_new = tmap[name[:-7]] + ".weight"
- elif name.endswith(".bias") and name[:-5] in tmap:
- name_new = tmap[name[:-5]] + ".bias"
- else:
+ tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
+ if name_new is None:
raise Exception(f"Unexpected tensor name: {name}")
- if gguf.should_skip_tensor_TMP(ARCH, params.n_layer, name_new):
+ if tensor_type in should_skip:
print(f"skipping tensor {name_new}")
continue
- else:
- print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
- out[name_new] = lazy_tensor
+
+ print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
+ out[name_new] = lazy_tensor
return out
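
Aside (hedged, not gguf-py source): the removed branches above show exactly what the new tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) call has to do, namely an exact lookup first, then a retry with each suffix stripped and re-appended. A standalone sketch of that lookup with a hypothetical mapping:

    # Sketch: suffix-aware tensor-name lookup equivalent to the removed if/elif chain.
    from typing import Dict, Optional, Tuple

    def get_type_and_name(mapping: Dict[str, Tuple[str, str]], name: str,
                          try_suffixes: Tuple[str, ...] = ()) -> Optional[Tuple[str, str]]:
        if name in mapping:
            return mapping[name]
        for suffix in try_suffixes:
            if name.endswith(suffix) and name[:-len(suffix)] in mapping:
                tensor_type, base = mapping[name[:-len(suffix)]]
                return tensor_type, base + suffix
        return None

    mapping = {'model.embed_tokens': ('TOKEN_EMBD', 'token_embd')}   # hypothetical entry
    print(get_type_and_name(mapping, 'model.embed_tokens.weight', ('.weight', '.bias')))
    # -> ('TOKEN_EMBD', 'token_embd.weight')
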
@@ -1117,8 +1130,16 @@ def main(args_in: Optional[List[str]] = None) -> None:
if args.dump_single:
model_plus = lazy_load_file(args.model)
do_dump_model(model_plus)
+ return
- model_plus = load_some_model(args.model)
+ if not args.vocab_only:
+ model_plus = load_some_model(args.model)
+ else:
+ model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
+
+ if args.dump:
+ do_dump_model(model_plus)
+ return
params = Params.load(model_plus)
if params.n_ctx == -1:
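
Aside (hedged): main() taking an optional args_in list is visible in the hunk above, which also shows that the new vocab-only path no longer loads model tensors. The flag spellings below are inferred from the args.vocab_only / args.outfile attribute names and may not match the real argument parser exactly; the model directory is hypothetical.

    # Sketch: exercising the vocab-only path programmatically instead of via the shell.
    import convert

    convert.main(args_in = [
        '--vocab-only',                  # assumed spelling of args.vocab_only
        '--outfile', 'my-vocab.gguf',    # required: "need --outfile if using --vocab-only"
        'models/my-model',               # hypothetical model/tokenizer directory (args.model)
    ])
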
@@ -1140,33 +1161,34 @@ def main(args_in: Optional[List[str]] = None) -> None:
vocab: Vocab
if args.vocab_only:
- vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
assert args.outfile, "need --outfile if using --vocab-only"
+ # FIXME: Try to respect vocab_dir somehow?
+ vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
+ special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
outfile = args.outfile
- OutputFile.write_vocab_only(outfile, params, vocab)
+ OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
print(f"Wrote {outfile}")
- else:
- if args.dump:
- do_dump_model(model_plus)
- return
+ return
- if model_plus.vocab is not None and args.vocab_dir is None:
- vocab = model_plus.vocab
- else:
- vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
- vocab = load_vocab(vocab_dir, args.vocabtype)
-
- model = model_plus.model
- model = convert_model_names(model, params)
- ftype = pick_output_type(model, args.outtype)
- model = convert_to_output_type(model, ftype)
- outfile = args.outfile or default_outfile(model_plus.paths, ftype)
-
- params.ftype = ftype
- print(f"Writing {outfile}, format {ftype}")
-
- OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
- print(f"Wrote {outfile}")
+ if model_plus.vocab is not None and args.vocab_dir is None:
+ vocab = model_plus.vocab
+ else:
+ vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
+ vocab = load_vocab(vocab_dir, args.vocabtype)
+ # FIXME: Try to respect vocab_dir somehow?
+ special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
+
+ model = model_plus.model
+ model = convert_model_names(model, params)
+ ftype = pick_output_type(model, args.outtype)
+ model = convert_to_output_type(model, ftype)
+ outfile = args.outfile or default_outfile(model_plus.paths, ftype)
+
+ params.ftype = ftype
+ print(f"Writing {outfile}, format {ftype}")
+
+ OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency)
+ print(f"Wrote {outfile}")
if __name__ == '__main__':