diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-08-23 23:08:04 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-23 23:08:04 +0300 |
commit | cf658adc832badaaa2ca119fe86070e5a830f8f6 (patch) | |
tree | e314db2fb18676067ddbc5cde0cf7f73c417af29 /convert-falcon-hf-to-gguf.py | |
parent | a192860cfec89a38d59a943623bf595b1fe4495b (diff) |
llm : add Falcon support (#2717)
* llama : refactor GGUF constants into static maps
* llama : check if model architecture is known
* llama : refactor llama_model_load_internal()
* gguf : add KV constant maps
* llm : read arch-specific KVs
* convert : add dummy scores + types
* falcon : load tensor data (CPU only)
* llama : fix loading progress bar
* llama : add arch member to llama_model
* falcon : CPU inference working
* falcon : support non-40B models
* falcon : minor
* llama : minor updates
ggml-ci
* convert-falcon-hf-to-gguf.py : fix special token mapping
* llama.cpp : llama default UNK token = id 0
* llama.cpp : fix bpe tokenizer
* llama.cpp : fix the fix of bpe tokenizer
* ggml : pass eps to ggml_norm
* metal : implement RoPE (mode = 2) + avoid ggml_repeat
* ggml : ggml_repeat always creates new tensor
* falcon : copy-paste self-attention from LLaMA
* metal : print extra compute pipeline info
* falcon : minor changes (still chasing the Metal problem)
* llama.cpp : fix linefeed token
* metal : fix GELU kernel numerical stability by using precise::tanh
* metal : temporary workaround for the concurrency optimization bug
* falcon : add CUDA offloading (#2739)
* llama : better model naming and size reporting
* llama : prep new tokenizer support
* llama : advanced BPE tokenizer based on ggllm.cpp implementation
* llama : remove obsolete comment
ggml-ci
* common : remove obsolete BPE API + disable test-tokenizer-1
* llama : revert BPE special-case in llama_byte_to_token()
* cuda : add TODOs for RoPE NeoX implementation
* llama : default special tokens based on vocab type
* perplexity : add log for start of tokenization
---------
Co-authored-by: klosax <131523366+klosax@users.noreply.github.com>
Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'convert-falcon-hf-to-gguf.py')
-rwxr-xr-x | convert-falcon-hf-to-gguf.py | 55 |
1 file changed, 25 insertions, 30 deletions
diff --git a/convert-falcon-hf-to-gguf.py b/convert-falcon-hf-to-gguf.py index 50069db5..43e20849 100755 --- a/convert-falcon-hf-to-gguf.py +++ b/convert-falcon-hf-to-gguf.py @@ -95,14 +95,17 @@ print("gguf: get model metadata") block_count = hparams["n_layer"] -gguf_writer.add_name(last_dir) +gguf_writer.add_name("Falcon") gguf_writer.add_context_length(2048) # not in config.json gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform gguf_writer.add_embedding_length(hparams["hidden_size"]) gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"]) gguf_writer.add_block_count(block_count) gguf_writer.add_head_count(hparams["n_head"]) -if "n_head_kv" in hparams: gguf_writer.add_head_count_kv(hparams["n_head_kv"]) +if "n_head_kv" in hparams: + gguf_writer.add_head_count_kv(hparams["n_head_kv"]) +else: + gguf_writer.add_head_count_kv(1) gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"]) # TOKENIZATION @@ -110,6 +113,8 @@ gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"]) print("gguf: get tokenizer metadata") tokens: List[str] = [] +scores: List[float] = [] +toktypes: List[int] = [] merges: List[str] = [] @@ -153,41 +158,30 @@ if Path(dir_model + "/tokenizer.json").is_file(): text = bytearray(pad_token) tokens.append(text) + scores.append(0.0) # dymmy + toktypes.append(gguf.TokenType.NORMAL) # dummy gguf_writer.add_token_list(tokens) + gguf_writer.add_token_scores(scores) + gguf_writer.add_token_types(toktypes) - if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file(): - print("gguf: get special token ids") - - with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f: - tokenizer_config = json.load(f) +print("gguf: get special token ids") +# Look for special tokens in config.json - # find special token ids +if "bos_token_id" in hparams and hparams["bos_token_id"] != None: + gguf_writer.add_bos_token_id(hparams["bos_token_id"]) - if "bos_token" in tokenizer_config: - for key 
in tokenizer_json["added_tokens"]: - if key["content"] == tokenizer_config["bos_token"]: - gguf_writer.add_bos_token_id(key["id"]) +if "eos_token_id" in hparams and hparams["eos_token_id"] != None: + gguf_writer.add_eos_token_id(hparams["eos_token_id"]) - if "eos_token" in tokenizer_config: - for key in tokenizer_json["added_tokens"]: - if key["content"] == tokenizer_config["eos_token"]: - gguf_writer.add_eos_token_id(key["id"]) +if "unk_token_id" in hparams and hparams["unk_token_id"] != None: + gguf_writer.add_unk_token_id(hparams["unk_token_id"]) - if "unk_token" in tokenizer_config: - for key in tokenizer_json["added_tokens"]: - if key["content"] == tokenizer_config["unk_token"]: - gguf_writer.add_unk_token_id(key["id"]) +if "sep_token_id" in hparams and hparams["sep_token_id"] != None: + gguf_writer.add_sep_token_id(hparams["sep_token_id"]) - if "sep_token" in tokenizer_config: - for key in tokenizer_json["added_tokens"]: - if key["content"] == tokenizer_config["sep_token"]: - gguf_writer.add_sep_token_id(key["id"]) - - if "pad_token" in tokenizer_config: - for key in tokenizer_json["added_tokens"]: - if key["content"] == tokenizer_config["pad_token"]: - gguf_writer.add_pad_token_id(key["id"]) +if "pad_token_id" in hparams and hparams["pad_token_id"] != None: + gguf_writer.add_pad_token_id(hparams["pad_token_id"]) # TENSORS @@ -195,8 +189,9 @@ if Path(dir_model + "/tokenizer.json").is_file(): tensor_map = gguf.get_tensor_name_map(ARCH,block_count) # params for qkv transform -n_head = hparams["n_head"] +n_head = hparams["n_head"] n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1 + head_dim = hparams["hidden_size"] // n_head # tensor info |