author     postmasters <namnguyen@google.com>  2024-01-02 03:51:28 -0800
committer  GitHub <noreply@github.com>         2024-01-02 13:51:28 +0200
commit     83e633c27efdf0eb0ba54249e784b0ea760b1007 (patch)
tree       30711187d9551899c546f9181f00456481873679 /gguf-py
parent     32866c5edde402f42ff4233bb89dcfcede34fd22 (diff)
llama : differentiate the KV dims in the attention (#4657)
* Add n_key_dim and n_value_dim

Some models use values that are not derived from `n_embd`. Also remove
`n_embd_head` and `n_embd_gqa` because it is not clear which "head" is
referred to (key or value).

Fix issue #4648.

* Fix `llm_build_kqv` to use `n_value_gqa`

* Rebase

* Rename variables

* Fix llm_build_kqv to be more generic wrt n_embd_head_k

* Update default values for n_embd_head_k and n_embd_head_v

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Fix llm_load_tensors: the asserts were not backcompat

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
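A minimal sketch of the default relationship the message describes, assuming the usual derivation from n_embd (the dimensions below are illustrative, not taken from this patch):

    # For most models the per-head key/value sizes are derived from n_embd,
    # so the new metadata keys only matter when a model deviates from this.
    n_embd, n_head = 4096, 32
    n_embd_head_k = n_embd // n_head  # 128; overridable via {arch}.attention.key_length
    n_embd_head_v = n_embd // n_head  # 128; overridable via {arch}.attention.value_length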
Diffstat (limited to 'gguf-py')
-rw-r--r--  gguf-py/gguf/constants.py    2
-rw-r--r--  gguf-py/gguf/gguf_writer.py  6
2 files changed, 8 insertions, 0 deletions
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index ae62cc57..f0a1c51f 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -46,6 +46,8 @@ class Keys:
         HEAD_COUNT_KV        = "{arch}.attention.head_count_kv"
         MAX_ALIBI_BIAS       = "{arch}.attention.max_alibi_bias"
         CLAMP_KQV            = "{arch}.attention.clamp_kqv"
+        KEY_LENGTH           = "{arch}.attention.key_length"
+        VALUE_LENGTH         = "{arch}.attention.value_length"
         LAYERNORM_EPS        = "{arch}.attention.layer_norm_epsilon"
         LAYERNORM_RMS_EPS    = "{arch}.attention.layer_norm_rms_epsilon"
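These templated keys expand per architecture at write time; a minimal sketch, assuming "llama" as the architecture name (the arch value here is just an example):

    from gguf.constants import Keys

    # Expands to "llama.attention.key_length"
    print(Keys.Attention.KEY_LENGTH.format(arch="llama"))
    # Expands to "llama.attention.value_length"
    print(Keys.Attention.VALUE_LENGTH.format(arch="llama"))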
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 73e02160..d93aaa87 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -333,6 +333,12 @@ class GGUFWriter:
     def add_head_count_kv(self, count: int) -> None:
         self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)

+    def add_key_length(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
+
+    def add_value_length(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
+
     def add_max_alibi_bias(self, bias: float) -> None:
         self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
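A hedged usage sketch of the new writer methods, as a conversion script might call them for a grouped-query-attention model; the file name and all dimension values are illustrative, not from this patch:

    from gguf import GGUFWriter

    writer = GGUFWriter("model.gguf", arch="llama")
    writer.add_head_count(32)
    writer.add_head_count_kv(8)    # grouped-query attention: fewer KV heads
    writer.add_key_length(128)     # per-head key size (n_embd_head_k)
    writer.add_value_length(128)   # per-head value size (n_embd_head_v)
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.close()

Because the commit keeps the old derivation as the default, models whose key/value sizes equal n_embd // n_head can omit the two new calls and remain readable by older loaders.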