diff options
author | Qin Yue Chen <71813199+chenqiny@users.noreply.github.com> | 2023-10-20 06:19:40 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-20 14:19:40 +0300 |
commit | 8cf19d60dc93809db8e51fedc811595eed9134c5 (patch) | |
tree | 879c1861fb50748c02ec031a1dcc3f6e732ca366 /convert.py | |
parent | a0edf73bda31c7c4e649e6f07c6fd30a729929cd (diff) |
gguf : support big endian platform (#3552)
* check whether platform is 390x if yes->do not import immintrin.h
* support s390x big endian
* support --bigendian option for s390x
1. verified with baichuan7b-chat with float 16 on s390x
2. verified with baichuan7b-chat
3. verified with chinese-alpaca-2-13b-f16
* update format based on editor-config checker result
* Update convert-baichuan-hf-to-gguf.py
* 1. check in ggml.c if endianess is not match
2. update GGUF version
3. change get_pack_prefix to property
4. update information log
* always use "GGUF" as beginng of GGUF file
* Compare "GGUF" with file header char by char
1. Set GGUF_MAGIC to "GGUF" string instead of int value
2. Compare "GGUF" char by char to ensure its byte order
3. Move bytes swap code from convert.py to gguf.py write_tensor_data
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'convert.py')
-rwxr-xr-x | convert.py | 20 |
1 files changed, 12 insertions, 8 deletions
@@ -803,8 +803,8 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None: class OutputFile: - def __init__(self, fname_out: Path) -> None: - self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) + def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian=gguf.GGUFEndian.LITTLE) -> None: + self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess) def add_meta_arch(self, params: Params) -> None: name = "LLaMA" @@ -875,10 +875,10 @@ class OutputFile: self.gguf.close() @staticmethod - def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab) -> None: + def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess:gguf.GGUFEndian=gguf.GGUFEndian.LITTLE) -> None: check_vocab_size(params, vocab) - of = OutputFile(fname_out) + of = OutputFile(fname_out, endianess=endianess) # meta data of.add_meta_arch(params) @@ -903,10 +903,10 @@ class OutputFile: return dt.quantize(arr) @staticmethod - def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY) -> None: + def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess=gguf.GGUFEndian.LITTLE) -> None: check_vocab_size(params, vocab) - of = OutputFile(fname_out) + of = OutputFile(fname_out, endianess=endianess) # meta data of.add_meta_arch(params) @@ -1123,8 +1123,9 @@ def main(args_in: list[str] | None = None) -> None: parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm") parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY) - args = parser.parse_args(args_in) + parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") + args = parser.parse_args(args_in) if args.dump_single: model_plus = lazy_load_file(args.model) do_dump_model(model_plus) @@ -1138,6 +1139,9 @@ def main(args_in: list[str] | None = None) -> None: if args.dump: do_dump_model(model_plus) return + endianess = gguf.GGUFEndian.LITTLE + if args.bigendian: + endianess = gguf.GGUFEndian.BIG params = Params.load(model_plus) if params.n_ctx == -1: @@ -1185,7 +1189,7 @@ def main(args_in: list[str] | None = None) -> None: params.ftype = ftype print(f"Writing {outfile}, format {ftype}") - OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency) + OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency, endianess=endianess) print(f"Wrote {outfile}") |