diff options
author | Cebtenzzre <cebtenzzre@gmail.com> | 2023-08-31 01:02:23 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-31 08:02:23 +0300 |
commit | 92d0b751a77a089e650983e9f1564ef4d31b32b9 (patch) | |
tree | 28beced44af777ec563f489518a14509994187dd /convert.py | |
parent | 8afe2280009ecbfc9de2c93b8f41283dc810609a (diff) |
convert : fix python 3.8 support, modernize type annotations (#2916)
* convert : fix python 3.8 support
* convert : sort imports
* convert : fix required parameters in convert-llama-ggmlv3-to-gguf
* convert : fix mypy errors in convert-llama-ggmlv3-to-gguf
* convert : use PEP 585 generics and PEP 604 unions
Now that we have `from __future__ import annotations`, we can use this
modern syntax in Python 3.7 instead of restricting support to Python 3.9
or 3.10 respectively.
* gguf.py : a tuple is already a tuple
* add mypy.ini
* convert : add necessary `type: ignore` comments
* gguf-py: bump version
Diffstat (limited to 'convert.py')
-rwxr-xr-x | convert.py | 149 |
1 files changed, 75 insertions, 74 deletions
@@ -1,9 +1,8 @@ #!/usr/bin/env python3 +from __future__ import annotations -import gguf import argparse import concurrent.futures -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor import copy import enum import faulthandler @@ -20,21 +19,23 @@ import struct import sys import time import zipfile -import numpy as np - from abc import ABCMeta, abstractmethod +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) -from sentencepiece import SentencePieceProcessor # type: ignore +from typing import IO, TYPE_CHECKING, Any, Callable, Generator, Iterable, Literal, Sequence, TypeVar + +import gguf +import numpy as np +from sentencepiece import SentencePieceProcessor # type: ignore[import] if TYPE_CHECKING: - from typing_extensions import TypeAlias + from typing import TypeAlias if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): faulthandler.register(signal.SIGUSR1) -NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]' +NDArray: TypeAlias = 'np.ndarray[Any, Any]' ARCH=gguf.MODEL_ARCH.LLAMA NAMES=gguf.MODEL_TENSOR_NAMES[ARCH] @@ -47,8 +48,8 @@ DEFAULT_CONCURRENCY = 8 @dataclass(frozen=True) class DataType: name: str - dtype: 'np.dtype[Any]' - valid_conversions: List[str] + dtype: np.dtype[Any] + valid_conversions: list[str] def elements_to_bytes(self, n_elements: int) -> int: return n_elements * self.dtype.itemsize @@ -65,7 +66,7 @@ DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_convers @dataclass(frozen=True) class QuantizedDataType(DataType): block_size: int - quantized_dtype: 'np.dtype[Any]' + quantized_dtype: np.dtype[Any] ggml_type: gguf.GGMLQuantizationType def quantize(self, arr: NDArray) -> NDArray: @@ -84,7 +85,7 @@ class Q8_0QuantizedDataType(QuantizedDataType): n_blocks = arr.size // self.block_size blocks = arr.reshape((n_blocks, self.block_size)) # Much faster implementation of block quantization contributed by @Cebtenzzre - def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]: + def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]: d = abs(blocks).max(axis = 1) / np.float32(127) with np.errstate(divide = 'ignore'): qs = (blocks / d[:, None]).round() @@ -98,13 +99,13 @@ DT_Q8_0 = Q8_0QuantizedDataType('Q8_0', quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))])) # Quantized types skipped here because they may also map to np.float32 -NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {} +NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {} for dt in (DT_BF16, DT_F16, DT_F32, DT_I32): if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE: raise ValueError(f'Invalid duplicate data type {dt}') NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt -SAFETENSORS_DATA_TYPES: Dict[str, DataType] = { +SAFETENSORS_DATA_TYPES: dict[str, DataType] = { 'BF16': DT_BF16, 'F16': DT_F16, 'F32': DT_F32, @@ -119,14 +120,14 @@ class GGMLFileType(enum.IntEnum): MostlyF16 = 1 # except 1d tensors MostlyQ8_0 = 7 # except 1d tensors - def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType: + def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType: dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self) if dt is None: raise ValueError(self) # 1D tensors are always F32. return dt if len(tensor.shape) > 1 else DT_F32 -GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = { +GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = { GGMLFileType.AllF32 : DT_F32, GGMLFileType.MostlyF16 : DT_F16, GGMLFileType.MostlyQ8_0: DT_Q8_0, @@ -148,13 +149,13 @@ class Params: n_head_kv: int f_norm_eps: float - f_rope_freq_base: Optional[float] = None - f_rope_scale: Optional[float] = None + f_rope_freq_base: float | None = None + f_rope_scale: float | None = None - ftype: Optional[GGMLFileType] = None + ftype: GGMLFileType | None = None # path to the directory containing the model files - path_model: Optional['Path'] = None + path_model: Path | None = None @staticmethod def find_n_mult(n_ff: int, n_embd: int) -> int: @@ -166,7 +167,7 @@ class Params: raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).") @staticmethod - def guessed(model: 'LazyModel') -> 'Params': + def guessed(model: LazyModel) -> Params: # try transformer naming first n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape @@ -202,7 +203,7 @@ class Params: ) @staticmethod - def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params': + def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: config = json.load(open(config_path)) n_vocab = config["vocab_size"] @@ -247,7 +248,7 @@ class Params: # LLaMA v2 70B params.json # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1 @staticmethod - def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params': + def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: config = json.load(open(config_path)) n_vocab = config["vocab_size"] if "vocab_size" in config else -1 @@ -291,7 +292,7 @@ class Params: ) @staticmethod - def load(model_plus: 'ModelPlus') -> 'Params': + def load(model_plus: ModelPlus) -> Params: hf_config_path = model_plus.paths[0].parent / "config.json" orig_config_path = model_plus.paths[0].parent / "params.json" @@ -314,9 +315,9 @@ class Params: # class BpeVocab: - def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) - added_tokens: Dict[str, int] + added_tokens: dict[str, int] if fname_added_tokens is not None: added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) else: @@ -335,9 +336,9 @@ class BpeVocab: self.fname_tokenizer = fname_tokenizer self.fname_added_tokens = fname_added_tokens - def bpe_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: tokenizer = self.bpe_tokenizer - from transformers.models.gpt2 import tokenization_gpt2 + from transformers.models.gpt2 import tokenization_gpt2 # type: ignore[import] byte_encoder = tokenization_gpt2.bytes_to_unicode() byte_decoder = {v: k for k, v in byte_encoder.items()} for i, item in enumerate(tokenizer): @@ -345,12 +346,12 @@ class BpeVocab: score: float = -i yield text, score, gguf.TokenType.USER_DEFINED - def added_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: for text in self.added_tokens_list: score = -1000.0 yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED - def all_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: yield from self.bpe_tokens() yield from self.added_tokens() @@ -359,9 +360,9 @@ class BpeVocab: class SentencePieceVocab: - def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer)) - added_tokens: Dict[str, int] + added_tokens: dict[str, int] if fname_added_tokens is not None: added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) else: @@ -380,7 +381,7 @@ class SentencePieceVocab: self.fname_tokenizer = fname_tokenizer self.fname_added_tokens = fname_added_tokens - def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: tokenizer = self.sentencepiece_tokenizer for i in range(tokenizer.vocab_size()): piece = tokenizer.id_to_piece(i) @@ -404,19 +405,19 @@ class SentencePieceVocab: yield text, score, toktype - def added_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: for text in self.added_tokens_list: score = -1000.0 yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED - def all_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: yield from self.sentencepiece_tokens() yield from self.added_tokens() def __repr__(self) -> str: return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>" -Vocab = Union[BpeVocab, SentencePieceVocab] +Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab' # # data loading @@ -436,15 +437,15 @@ class Tensor(metaclass=ABCMeta): data_type: DataType @abstractmethod - def astype(self, data_type: DataType) -> 'Tensor': ... + def astype(self, data_type: DataType) -> Tensor: ... @abstractmethod - def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ... + def permute(self, n_head: int, n_head_kv: int) -> Tensor: ... @abstractmethod - def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor': ... + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ... @abstractmethod - def part(self, n_part: int) -> 'UnquantizedTensor': ... + def part(self, n_part: int) -> UnquantizedTensor: ... @abstractmethod - def to_ggml(self) -> 'GGMLCompatibleTensor': ... + def to_ggml(self) -> GGMLCompatibleTensor: ... def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray: @@ -465,22 +466,22 @@ class UnquantizedTensor(Tensor): self.ndarray = bf16_to_fp32(self.ndarray) return UnquantizedTensor(self.ndarray.astype(dtype)) - def to_ggml(self) -> 'UnquantizedTensor': + def to_ggml(self) -> UnquantizedTensor: return self - def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor': + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: r = self.ndarray.shape[0] // 3 return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv)) - def part(self, n_part: int) -> 'UnquantizedTensor': + def part(self, n_part: int) -> UnquantizedTensor: r = self.ndarray.shape[0] // 3 return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...]) - def permute(self, n_head: int, n_head_kv: int) -> 'UnquantizedTensor': + def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor: return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv)) -def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray: +def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray: tensor = lazy_tensor.load() assert isinstance(tensor, UnquantizedTensor) @@ -496,13 +497,13 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv return tensor.ndarray -GGMLCompatibleTensor = Union[UnquantizedTensor] +GGMLCompatibleTensor = UnquantizedTensor @dataclass class LazyTensor: _load: Callable[[], Tensor] - shape: List[int] + shape: list[int] data_type: DataType description: str @@ -513,7 +514,7 @@ class LazyTensor: (self.data_type, ret.data_type, self.description) return ret - def astype(self, data_type: DataType) -> 'LazyTensor': + def astype(self, data_type: DataType) -> LazyTensor: self.validate_conversion_to(data_type) def load() -> Tensor: @@ -525,24 +526,24 @@ class LazyTensor: raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.') -LazyModel = Dict[str, LazyTensor] +LazyModel = dict[str, LazyTensor] @dataclass class ModelPlus: model: LazyModel - paths: List[Path] # Where this was read from. + paths: list[Path] # Where this was read from. format: Literal['ggml', 'torch', 'safetensors', 'none'] - vocab: Optional[Vocab] # For GGML models (which have vocab built in), the vocab. + vocab: Vocab | None # For GGML models (which have vocab built in), the vocab. -def merge_sharded(models: List[LazyModel]) -> LazyModel: +def merge_sharded(models: list[LazyModel]) -> LazyModel: # Original LLaMA models have each file contain one part of each tensor. # Use a dict instead of a set to preserve order. names = {name: None for model in models for name in model} def convert(name: str) -> LazyTensor: - lazy_tensors: List[LazyTensor] = [model[name] for model in models] + lazy_tensors: list[LazyTensor] = [model[name] for model in models] if len(lazy_tensors) == 1: # only one file; don't go through this procedure since there might # be quantized tensors @@ -570,7 +571,7 @@ def merge_sharded(models: List[LazyModel]) -> LazyModel: return {name: convert(name) for name in names} -def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus: +def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: formats = set(mp.format for mp in models_plus) assert len(formats) == 1, "different formats?" format = formats.pop() @@ -674,7 +675,7 @@ class LazyUnpickler(pickle.Unpickler): def rebuild_from_type_v2(func, new_type, args, state): return func(*args) - CLASSES: Dict[Tuple[str, str], Any] = { + CLASSES: dict[tuple[str, str], Any] = { # getattr used here as a workaround for mypy not being smart enough to detrmine # the staticmethods have a __func__ attribute. ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), @@ -707,15 +708,15 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: header_size, = struct.unpack('<Q', fp.read(8)) - header: Dict[str, Dict[str, Any]] = json.loads(fp.read(header_size)) + header: dict[str, dict[str, Any]] = json.loads(fp.read(header_size)) # Use mmap for the actual data to avoid race conditions with the file offset. mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ)) byte_buf = mapped[8 + header_size:] - def convert(info: Dict[str, Any]) -> LazyTensor: + def convert(info: dict[str, Any]) -> LazyTensor: data_type = SAFETENSORS_DATA_TYPES[info['dtype']] numpy_dtype = data_type.dtype - shape: List[int] = info['shape'] + shape: list[int] = info['shape'] begin, end = info['data_offsets'] assert 0 <= begin <= end <= len(byte_buf) assert end - begin == math.prod(shape) * numpy_dtype.itemsize @@ -754,7 +755,7 @@ def lazy_load_file(path: Path) -> ModelPlus: In = TypeVar('In') Out = TypeVar('Out') -def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, use_processpool_executor: bool = False) -> Iterable[Out]: +def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: int | None = None, use_processpool_executor: bool = False) -> Iterable[Out]: '''Parallel map, but with backpressure. If the caller doesn't call `next` fast enough, this will stop calling `func` at some point rather than letting results pile up in memory. Specifically, there is a max of one @@ -763,13 +764,13 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc yield from map(func, iterable) # Not reached. iterable = iter(iterable) - executor_class: Union[Type[ThreadPoolExecutor], Type[ProcessPoolExecutor]] + executor_class: type[ThreadPoolExecutor] | type[ProcessPoolExecutor] if use_processpool_executor: executor_class = ProcessPoolExecutor else: executor_class = ThreadPoolExecutor with executor_class(max_workers = max_workers) as executor: - futures: List[concurrent.futures.Future[Out]] = [] + futures: list[concurrent.futures.Future[Out]] = [] done = False for _ in range(concurrency): try: @@ -893,13 +894,13 @@ class OutputFile: of.close() @staticmethod - def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]: + def do_item(item: tuple[str, LazyTensor]) -> tuple[DataType, NDArray]: name, lazy_tensor = item tensor = lazy_tensor.load().to_ggml() return (lazy_tensor.data_type, tensor.ndarray) @staticmethod - def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray: + def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray: dt, arr = item if not isinstance(dt, QuantizedDataType): return arr @@ -940,7 +941,7 @@ class OutputFile: of.close() -def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType: +def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType: wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32): @@ -960,7 +961,7 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM def convert_model_names(model: LazyModel, params: Params) -> LazyModel: tmap = gguf.TensorNameMap(ARCH, params.n_layer) - should_skip: Set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, [])) + should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, [])) tmp = model @@ -995,12 +996,12 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel: return out -def nth_multifile_path(path: Path, n: int) -> Optional[Path]: +def nth_multifile_path(path: Path, n: int) -> Path | None: '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return the nth path in the model. ''' # Support the following patterns: - patterns: List[Tuple[str, str]] = [ + patterns: list[tuple[str, str]] = [ # - x.00.pth, x.01.pth, etc. (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'), # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc. @@ -1016,11 +1017,11 @@ def nth_multifile_path(path: Path, n: int) -> Optional[Path]: return None -def find_multifile_paths(path: Path) -> List[Path]: +def find_multifile_paths(path: Path) -> list[Path]: '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return the whole list of paths in the model. ''' - ret: List[Path] = [] + ret: list[Path] = [] for i in itertools.count(): nth_path = nth_multifile_path(path, i) if nth_path is None: @@ -1051,7 +1052,7 @@ def load_some_model(path: Path) -> ModelPlus: path = files[0] paths = find_multifile_paths(path) - models_plus: List[ModelPlus] = [] + models_plus: list[ModelPlus] = [] for path in paths: print(f"Loading model file {path}") models_plus.append(lazy_load_file(path)) @@ -1060,7 +1061,7 @@ def load_some_model(path: Path) -> ModelPlus: return model_plus -def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]: +def load_vocab(path: Path, vocabtype: str | None) -> Vocab: # Be extra-friendly and accept either a file or a directory. Also, if it's # a directory, it might be the model directory, and tokenizer.model might # be in the parent of that. @@ -1091,7 +1092,7 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, Sentence raise ValueError(f"Unsupported vocabulary type {vocabtype}") -def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path: +def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path: namestr = { GGMLFileType.AllF32: "f32", GGMLFileType.MostlyF16: "f16", @@ -1114,7 +1115,7 @@ def do_dump_model(model_plus: ModelPlus) -> None: print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") -def main(args_in: Optional[List[str]] = None) -> None: +def main(args_in: list[str] | None = None) -> None: parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") |