diff options
author | Douglas Hanley <thesecretaryofwar@gmail.com> | 2024-02-15 11:21:49 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-15 12:21:49 -0500 |
commit | 4524290e87b8e107cc2b56e1251751546f4b9051 (patch) | |
tree | 38d50aa2850bc2ecb53619fb9f03e0f91953a4c6 /gguf-py | |
parent | c06e45d72983d9ace7b1535f7e7ea258d212169e (diff) |
Use correct type of pooling for embedding models (#5500)
Use correct type of pooling for embedding models
Diffstat (limited to 'gguf-py')
-rw-r--r-- | gguf-py/gguf/constants.py | 8 | ||||
-rw-r--r-- | gguf-py/gguf/gguf_writer.py | 5 |
2 files changed, 10 insertions, 3 deletions
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 9986ce9d..114a9a97 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -40,7 +40,7 @@ class Keys: TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" EXPERT_COUNT = "{arch}.expert_count" EXPERT_USED_COUNT = "{arch}.expert_used_count" - POOLING_LAYER = "{arch}.pooling_layer" + POOLING_TYPE = "{arch}.pooling_type" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -561,6 +561,12 @@ class RopeScalingType(Enum): YARN = 'yarn' +class PoolingType(IntEnum): + NONE = 0 + MEAN = 1 + CLS = 2 + + class GGMLQuantizationType(IntEnum): F32 = 0 F16 = 1 diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 26724bf9..e4681475 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -19,6 +19,7 @@ from .constants import ( GGUFValueType, Keys, RopeScalingType, + PoolingType, TokenType, ) @@ -360,8 +361,8 @@ class GGUFWriter: def add_causal_attention(self, value: bool) -> None: self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) - def add_pooling_layer(self, value: bool) -> None: - self.add_bool(Keys.LLM.POOLING_LAYER.format(arch=self.arch), value) + def add_pooling_type(self, value: PoolingType) -> None: + self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value) def add_rope_dimension_count(self, count: int) -> None: self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count) |