summaryrefslogtreecommitdiff
path: root/convert-hf-to-gguf.py
diff options
context:
space:
mode:
authorfairydreaming <166155368+fairydreaming@users.noreply.github.com>2024-05-23 11:49:53 +0200
committerGitHub <noreply@github.com>2024-05-23 11:49:53 +0200
commit9b82476ee9e73065a759f8bcc4cf27ec7ab2ed8c (patch)
treed4881d12bc7e60750f90e642e3fabbdf4029fc53 /convert-hf-to-gguf.py
parenta61a94e543e3c6877c087e80fca27a0313ce5fd5 (diff)
Add missing inference support for GPTNeoXForCausalLM (Pythia and GPT-NeoX base models) (#7461)
* convert-hf : add conversion of bloom-style qkv tensor to gpt-style qkv (code borrowed from BloomModel) * llama : add inference support for LLM_ARCH_GPTNEOX * llama : add model types for every Pythia variant and GPT-NeoX Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-xconvert-hf-to-gguf.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index daad1c4f..5a00a5e8 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -673,6 +673,44 @@ class GPTNeoXModel(Model):
self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ del bid # unused
+
+ n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
+ n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+
+ tensors: list[tuple[str, Tensor]] = []
+
+ if re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.weight", name):
+ # Map bloom-style qkv_linear to gpt-style qkv_linear
+ # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa
+ # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
+ qkv_weights = data_torch.reshape((n_head, 3, n_embed // n_head, n_embed))
+ data_torch = torch.cat(
+ (
+ qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
+ qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
+ qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
+ ),
+ dim=0,
+ )
+ logger.info("re-format attention.linear_qkv.weight")
+ elif re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.bias", name):
+ qkv_bias = data_torch.reshape((n_head, 3, n_embed // n_head))
+ data_torch = torch.cat(
+ (
+ qkv_bias[:, 0, :].reshape((n_embed,)),
+ qkv_bias[:, 1, :].reshape((n_embed,)),
+ qkv_bias[:, 2, :].reshape((n_embed,)),
+ ),
+ dim=0,
+ )
+ logger.info("re-format attention.linear_qkv.bias")
+
+ tensors.append((self.map_tensor_name(name), data_torch))
+
+ return tensors
+
@Model.register("BloomForCausalLM")
class BloomModel(Model):