| author | Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> | 2023-08-30 02:25:50 -0600 |
| committer | GitHub <noreply@github.com> | 2023-08-30 11:25:50 +0300 |
| commit | dc07dc492ef9640bbb82904d7c7679f7bdcf6d76 (patch) | |
| tree | f9d80bc6ee29067e8e72521d75dfa2b92d85540e /convert.py | |
| parent | ad9ddcff6ef322db5cf13785bd7c856b610d242e (diff) | |
convert : various script cleanups/fixes + merges and special token handling (#2842)
* convert: Fix permute calls and method/func definitions
* Cleanups for gguf-py
* Minor type cleanups.
* Initial implementation of handling merges and special tokens (a short usage sketch follows this list)
* convert: Handle special tokens and merges in vocab only mode
convert: Vocab only mode no longer requires loading model tensors
* gguf: Refactor tensor name mapping
* convert: Fix type hint for special_token_types in SpecialVocab
* Use common special vocab handling in various conversion scripts
* First pass at implementing suggested changes
* Second pass
* gguf: SpecialVocab: Fix issue with special token content not in a dict
gguf: SpecialVocab: Allow skipping handling of merges
* convert-falcon-hf-to-gguf: Support --vocab-only option, bail out if no tokenizer.json
* convert-gptneox-hf-to-gguf and convert: Only handle merges for BPE tokenizer
* gguf: SpecialVocab: Actually set load_merges in object
* Uniform args parsing and vocab only mode for convert examples
* convert.py: Set gpt2 as tokenizer model when using BPE
* Squish last type warning in gguf.py - yay!
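As referenced in the merges/special-token bullet above, the new handling funnels through a single `gguf.SpecialVocab` helper: it is constructed from the model directory and written into the output via `add_to_gguf`, the two calls visible in the diff below. A minimal usage sketch, with a placeholder model directory and file name, reusing `gguf.GGUFWriter` the way `OutputFile` in convert.py does; everything outside the two `SpecialVocab` calls is illustrative only:

```python
from pathlib import Path

import gguf  # the gguf-py package bundled with llama.cpp

model_dir = Path("models/my-model")  # placeholder: directory holding tokenizer/config files

# A writer for a hypothetical vocab-only output file; "llama" matches ARCH in convert.py.
writer = gguf.GGUFWriter("vocab-only.gguf", "llama")

# load_merges mirrors `args.vocabtype == 'bpe'` in convert.py: merges only exist for BPE vocabs.
special_vocab = gguf.SpecialVocab(model_dir, load_merges=True)
special_vocab.add_to_gguf(writer)  # writes merges and special-token IDs as GGUF metadata
```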
Diffstat (limited to 'convert.py')
-rwxr-xr-x | convert.py | 148 |
1 file changed, 85 insertions, 63 deletions
@@ -25,7 +25,7 @@ import numpy as np
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, Type, TypeVar, Union)
 from sentencepiece import SentencePieceProcessor  # type: ignore

 if TYPE_CHECKING:
@@ -299,8 +299,10 @@ class Params:
             params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
         elif orig_config_path.exists():
             params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
-        else:
+        elif model_plus.format != 'none':
             params = Params.guessed(model_plus.model)
+        else:
+            raise ValueError('Cannot guess params when model format is none')

         params.path_model = model_plus.paths[0].parent
@@ -353,7 +355,7 @@ class BpeVocab:
         yield from self.added_tokens()

     def __repr__(self) -> str:
-        return f"BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
+        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"

 class SentencePieceVocab:
@@ -416,7 +418,6 @@ class SentencePieceVocab:
 Vocab = Union[BpeVocab, SentencePieceVocab]

-
 #
 # data loading
 # TODO: reuse (probably move to gguf.py?)
@@ -439,14 +440,14 @@ class Tensor(metaclass=ABCMeta):
     @abstractmethod
     def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ...
     @abstractmethod
-    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
+    def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor': ...
     @abstractmethod
     def part(self, n_part: int) -> 'UnquantizedTensor': ...
     @abstractmethod
     def to_ggml(self) -> 'GGMLCompatibleTensor': ...

-def bf16_to_fp32(bf16_arr: np.ndarray) -> np.ndarray:
+def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
     assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
     fp32_arr = bf16_arr.astype(np.uint32) << 16
     return fp32_arr.view(np.float32)
@@ -467,9 +468,9 @@ class UnquantizedTensor(Tensor):
     def to_ggml(self) -> 'UnquantizedTensor':
         return self

-    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
+    def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
         r = self.ndarray.shape[0] // 3
-        return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head))
+        return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv))

     def part(self, n_part: int) -> 'UnquantizedTensor':
         r = self.ndarray.shape[0] // 3
@@ -531,7 +532,7 @@ LazyModel = Dict[str, LazyTensor]
 class ModelPlus:
     model: LazyModel
     paths: List[Path]  # Where this was read from.
-    format: Literal['ggml', 'torch', 'safetensors']
+    format: Literal['ggml', 'torch', 'safetensors', 'none']
     vocab: Optional[Vocab]  # For GGML models (which have vocab built in), the vocab.
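A side note on the `bf16_to_fp32` hunk above: only the annotations change, but the trick the function relies on is easy to verify. bfloat16 is simply the upper 16 bits of an IEEE-754 float32, so widening the raw uint16 storage to uint32, shifting it into the high half, and reinterpreting the buffer as float32 reproduces the value exactly. A self-contained illustration (not part of the patch), using a few hand-picked bit patterns:

```python
import numpy as np

# Pretend these uint16 values are raw bfloat16 storage:
# 0x3F80 == 1.0, 0x4049 == 3.140625, 0xC000 == -2.0
bf16_raw = np.array([0x3F80, 0x4049, 0xC000], dtype=np.uint16)

# Same arithmetic as convert.py's bf16_to_fp32: widen, shift into the high 16 bits, reinterpret.
fp32 = (bf16_raw.astype(np.uint32) << 16).view(np.float32)
print(fp32)  # -> [1.0, 3.140625, -2.0]
```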
@@ -597,12 +598,12 @@ def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTe
         return lazy_tensor.load().permute(n_head, n_head_kv)
     return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)

-def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
+def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor:
     def load() -> Tensor:
-        return lazy_tensor.load().permute_part(n_part, n_head)
+        return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv)
     s = lazy_tensor.shape.copy()
     s[0] = s[0] // 3
-    return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
+    return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)

 def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
     def load() -> Tensor:
@@ -657,7 +658,7 @@ class LazyUnpickler(pickle.Unpickler):
             description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
             return LazyStorage(load=load, kind=pid[1], description=description)

-    # @staticmethod
+    @staticmethod
     def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,  # pyright: ignore[reportSelfClsParameterName]
                                requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
@@ -669,13 +670,15 @@ class LazyUnpickler(pickle.Unpickler):
         description = f'pickled storage_offset={storage_offset} in {storage.description}'
         return LazyTensor(load, list(size), storage.kind.data_type, description)

-    # @staticmethod
+    @staticmethod
     def rebuild_from_type_v2(func, new_type, args, state):
         return func(*args)

-    CLASSES: Dict[Any, Any] = {
-        ('torch._tensor', '_rebuild_from_type_v2'): rebuild_from_type_v2,
-        ('torch._utils', '_rebuild_tensor_v2'): lazy_rebuild_tensor_v2,
+    CLASSES: Dict[Tuple[str, str], Any] = {
+        # getattr used here as a workaround for mypy not being smart enough to detrmine
+        # the staticmethods have a __func__ attribute.
+        ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
+        ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
         ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
         ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
         ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
@@ -751,7 +754,7 @@ def lazy_load_file(path: Path) -> ModelPlus:
 In = TypeVar('In')
 Out = TypeVar('Out')

-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, use_processpool_executor: bool = False) -> Iterable[Out]:
     '''Parallel map, but with backpressure.  If the caller doesn't call `next`
     fast enough, this will stop calling `func` at some point rather than
     letting results pile up in memory.  Specifically, there is a max of one
@@ -760,7 +763,12 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
         yield from map(func, iterable)  # Not reached.
     iterable = iter(iterable)
-    with factory(max_workers = max_workers) as executor:
+    executor_class: Union[Type[ThreadPoolExecutor], Type[ProcessPoolExecutor]]
+    if use_processpool_executor:
+        executor_class = ProcessPoolExecutor
+    else:
+        executor_class = ThreadPoolExecutor
+    with executor_class(max_workers = max_workers) as executor:
         futures: List[concurrent.futures.Future[Out]] = []
         done = False
         for _ in range(concurrency):
@@ -838,11 +846,19 @@ class OutputFile:
             scores.append(score)
             toktypes.append(toktype)

-        self.gguf.add_tokenizer_model("llama")
+        if isinstance(vocab, SentencePieceVocab):
+            self.gguf.add_tokenizer_model("llama")
+        elif isinstance(vocab, BpeVocab):
+            self.gguf.add_tokenizer_model("gpt2")
+        else:
+            raise ValueError(f'Unknown vocab type: Not BpeVocab or SentencePieceVocab')
         self.gguf.add_token_list(tokens)
         self.gguf.add_token_scores(scores)
         self.gguf.add_token_types(toktypes)

+    def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None:
+        svocab.add_to_gguf(self.gguf)
+
     def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
         n_elements = int(np.prod(tensor.shape))
         raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
@@ -861,7 +877,7 @@ class OutputFile:
         self.gguf.close()

     @staticmethod
-    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
+    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab) -> None:
         check_vocab_size(params, vocab)

         of = OutputFile(fname_out)
@@ -869,6 +885,8 @@ class OutputFile:
         # meta data
         of.add_meta_arch(params)
         of.add_meta_vocab(vocab)
+        of.add_meta_special_vocab(svocab)
+
         of.write_meta()

         of.close()
@@ -887,7 +905,7 @@ class OutputFile:
             return dt.quantize(arr)

     @staticmethod
-    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
         check_vocab_size(params, vocab)

         of = OutputFile(fname_out)
@@ -895,6 +913,7 @@ class OutputFile:
         # meta data
         of.add_meta_arch(params)
         of.add_meta_vocab(vocab)
+        of.add_meta_special_vocab(svocab)

         # tensor info
         for name, lazy_tensor in model.items():
@@ -906,7 +925,7 @@ class OutputFile:
         # tensor data
         ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
         if ftype == GGMLFileType.MostlyQ8_0:
-            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, use_processpool_executor = True)
         else:
             ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
@@ -939,7 +958,8 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM
                 for (name, tensor) in model.items()}

 def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
-    tmap = gguf.get_tensor_name_map(ARCH, params.n_layer)
+    tmap = gguf.TensorNameMap(ARCH, params.n_layer)
+    should_skip: Set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))

     tmp = model
@@ -952,8 +972,8 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
            #tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
         elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
             print(f"Unpacking and permuting layer {i}")
-            tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
-            tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
+            tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head)
+            tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv)
             tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
             del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
         else:
@@ -961,23 +981,16 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
     out: LazyModel = {}
     for name, lazy_tensor in model.items():
-        name_new = name
-
-        if name in tmap:
-            name_new = tmap[name]
-        elif name.endswith(".weight") and name[:-7] in tmap:
-            name_new = tmap[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tmap:
-            name_new = tmap[name[:-5]] + ".bias"
-        else:
+        tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
+        if name_new is None:
             raise Exception(f"Unexpected tensor name: {name}")

-        if gguf.should_skip_tensor_TMP(ARCH, params.n_layer, name_new):
+        if tensor_type in should_skip:
             print(f"skipping tensor {name_new}")
             continue
-        else:
-            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
-            out[name_new] = lazy_tensor
+
+        print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
+        out[name_new] = lazy_tensor

     return out
@@ -1117,8 +1130,16 @@ def main(args_in: Optional[List[str]] = None) -> None:
     if args.dump_single:
         model_plus = lazy_load_file(args.model)
         do_dump_model(model_plus)
+        return

-    model_plus = load_some_model(args.model)
+    if not args.vocab_only:
+        model_plus = load_some_model(args.model)
+    else:
+        model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
+
+    if args.dump:
+        do_dump_model(model_plus)
+        return

     params = Params.load(model_plus)
     if params.n_ctx == -1:
@@ -1140,33 +1161,34 @@ def main(args_in: Optional[List[str]] = None) -> None:
     vocab: Vocab
     if args.vocab_only:
-        vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
         assert args.outfile, "need --outfile if using --vocab-only"
+        # FIXME: Try to respect vocab_dir somehow?
+        vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
+        special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
         outfile = args.outfile
-        OutputFile.write_vocab_only(outfile, params, vocab)
+        OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
         print(f"Wrote {outfile}")
-    else:
-        if args.dump:
-            do_dump_model(model_plus)
-            return
+        return

-        if model_plus.vocab is not None and args.vocab_dir is None:
-            vocab = model_plus.vocab
-        else:
-            vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
-            vocab = load_vocab(vocab_dir, args.vocabtype)
-
-        model = model_plus.model
-        model = convert_model_names(model, params)
-        ftype = pick_output_type(model, args.outtype)
-        model = convert_to_output_type(model, ftype)
-        outfile = args.outfile or default_outfile(model_plus.paths, ftype)
-
-        params.ftype = ftype
-        print(f"Writing {outfile}, format {ftype}")
-
-        OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
-        print(f"Wrote {outfile}")
+    if model_plus.vocab is not None and args.vocab_dir is None:
+        vocab = model_plus.vocab
+    else:
+        vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
+        vocab = load_vocab(vocab_dir, args.vocabtype)
+    # FIXME: Try to respect vocab_dir somehow?
+    special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
+
+    model = model_plus.model
+    model = convert_model_names(model, params)
+    ftype = pick_output_type(model, args.outtype)
+    model = convert_to_output_type(model, ftype)
+    outfile = args.outfile or default_outfile(model_plus.paths, ftype)
+
+    params.ftype = ftype
+    print(f"Writing {outfile}, format {ftype}")
+
+    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency)
+    print(f"Wrote {outfile}")

 if __name__ == '__main__':
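For readers skimming the `bounded_parallel_map` changes above: the `factory` argument is replaced by a `use_processpool_executor` flag, and the backpressure behaviour described in the docstring is unchanged. The simplified sketch below is not the convert.py implementation; it only illustrates the same two ideas, flag-based executor selection and keeping at most `concurrency` items in flight. The function name and the refill policy here are made up for illustration:

```python
import itertools
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Callable, Iterable, Iterator, TypeVar

In = TypeVar('In')
Out = TypeVar('Out')

def bounded_map_sketch(func: Callable[[In], Out], items: Iterable[In], concurrency: int,
                       use_processpool_executor: bool = False) -> Iterator[Out]:
    # Same flag-based selection the commit introduces; a process pool requires
    # `func` and the items to be picklable.
    executor_class = ProcessPoolExecutor if use_processpool_executor else ThreadPoolExecutor
    it = iter(items)
    with executor_class(max_workers=concurrency) as executor:
        # Prime at most `concurrency` submissions, then submit one new item per result
        # the caller consumes, so unread results never pile up in memory.
        futures = [executor.submit(func, x) for x in itertools.islice(it, concurrency)]
        while futures:
            yield futures.pop(0).result()
            for x in itertools.islice(it, 1):
                futures.append(executor.submit(func, x))
```

In the patch itself the flag is only switched on for the `maybe_do_quantize` stage when writing Q8_0 output, where the NumPy-heavy quantization work benefits from real processes rather than GIL-bound threads.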