summaryrefslogtreecommitdiff
path: root/convert-lora-to-ggml.py
diff options
context:
space:
mode:
author: kuronekosaiko <EvanChanJ@163.com> 2024-01-22 00:28:14 +0800
committer: GitHub <noreply@github.com> 2024-01-21 17:28:14 +0100
commit: 05490fad7f7f60ff2bed9ad05cd81b44e82ccde3 (patch)
tree: d77e15d2b7977163c43209f79240e805629081e0 /convert-lora-to-ggml.py
parent: 6c5629d4d2d15557c38a0e609b30c1a42abad80d (diff)
add safetensors support to convert-lora-to-ggml.py (#5062)
* add safetensors support to convert-lora-to-ggml.py
* Update convert-lora-to-ggml.py: remove white space in line 69.
Diffstat (limited to 'convert-lora-to-ggml.py')
-rwxr-xr-x convert-lora-to-ggml.py | 9
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py
index 4904bf12..9a9936de 100755
--- a/convert-lora-to-ggml.py
+++ b/convert-lora-to-ggml.py
@@ -59,7 +59,14 @@ if __name__ == '__main__':
input_model = os.path.join(sys.argv[1], "adapter_model.bin")
output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
- model = torch.load(input_model, map_location="cpu")
+ if os.path.exists(input_model):
+ model = torch.load(input_model, map_location="cpu")
+ else:
+ input_model = os.path.join(sys.argv[1], "adapter_model.safetensors")
+ # lazy import load_file only if lora is in safetensors format.
+ from safetensors.torch import load_file
+ model = load_file(input_model, device="cpu")
+
arch_name = sys.argv[2] if len(sys.argv) == 3 else "llama"
if arch_name not in gguf.MODEL_ARCH_NAMES.values():