diff options
Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-x | convert-hf-to-gguf.py | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index ae471481..9771fccf 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1650,7 +1650,29 @@ class BertModel(Model):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_causal_attention(False)
-        self.gguf_writer.add_pooling_layer(True)
+
+        # get pooling path
+        with open(self.dir_model / "modules.json", encoding="utf-8") as f:
+            modules = json.load(f)
+        pooling_path = None
+        for mod in modules:
+            if mod["type"] == "sentence_transformers.models.Pooling":
+                pooling_path = mod["path"]
+                break
+
+        # get pooling type
+        pooling_type = gguf.PoolingType.NONE
+        if pooling_path is not None:
+            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
+                pooling = json.load(f)
+            if pooling["pooling_mode_mean_tokens"]:
+                pooling_type = gguf.PoolingType.MEAN
+            elif pooling["pooling_mode_cls_token"]:
+                pooling_type = gguf.PoolingType.CLS
+            else:
+                raise NotImplementedError("Only MEAN and CLS pooling types supported")
+
+        self.gguf_writer.add_pooling_type(pooling_type.value)
 
     def set_vocab(self):
         path = self.dir_model