summaryrefslogtreecommitdiff
path: root/convert.py
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-08-22 20:05:59 +0300
committerGitHub <noreply@github.com>2023-08-22 20:05:59 +0300
commitdeb7dfca4b9725cd295d1426db75fe8e0a6d5312 (patch)
treef36daf023af86b6005325cbb4ee80a7966255e59 /convert.py
parentbac66994cf356cf488078c056831396eb4ce31d5 (diff)
gguf : add ftype meta info to the model (#2710)
* llama : add ftype meta info to the model ggml-ci * convert.py : add ftype when converting (does not work) * convert.py : fix Enum to IntEnum ggml-ci
Diffstat (limited to 'convert.py')
-rw-r--r--convert.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/convert.py b/convert.py
index c29c032c..71978d67 100644
--- a/convert.py
+++ b/convert.py
@@ -69,7 +69,10 @@ SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
'I32': DT_I32,
}
-class GGMLFileType(enum.Enum):
+# TODO: match this with `llama_ftype`
+# TODO: rename to LLAMAFileType
+# TODO: move to `gguf.py`
+class GGMLFileType(enum.IntEnum):
AllF32 = 0
MostlyF16 = 1 # except 1d tensors
@@ -101,6 +104,8 @@ class Params:
n_head_kv: int
f_norm_eps: float
+ ftype: Optional[GGMLFileType] = None
+
@staticmethod
def find_n_mult(n_ff: int, n_embd: int) -> int:
# hardcoded magic range
@@ -738,6 +743,9 @@ class OutputFile:
self.gguf.add_head_count_kv (params.n_head_kv)
self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)
+ if params.ftype:
+ self.gguf.add_file_type(params.ftype)
+
def add_meta_vocab(self, vocab: Vocab) -> None:
tokens = []
scores = []
@@ -1020,6 +1028,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
" - LLaMA v2: --ctx 4096\n")
params.n_ctx = args.ctx
+ if args.outtype:
+ params.ftype = {
+ "f32": GGMLFileType.AllF32,
+ "f16": GGMLFileType.MostlyF16,
+ }[args.outtype]
+
print(f"params = {params}")
vocab: Vocab
@@ -1040,11 +1054,14 @@ def main(args_in: Optional[List[str]] = None) -> None:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
vocab = load_vocab(vocab_dir, args.vocabtype)
- model = model_plus.model
- model = convert_model_names(model, params)
- output_type = pick_output_type(model, args.outtype)
- model = convert_to_output_type(model, output_type)
- outfile = args.outfile or default_outfile(model_plus.paths, output_type)
+ model = model_plus.model
+ model = convert_model_names(model, params)
+ ftype = pick_output_type(model, args.outtype)
+ model = convert_to_output_type(model, ftype)
+ outfile = args.outfile or default_outfile(model_plus.paths, ftype)
+
+ params.ftype = ftype
+ print(f"Writing {outfile}, format {ftype}")
OutputFile.write_all(outfile, params, model, vocab)
print(f"Wrote {outfile}")