summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--  examples/llava/llava-surgery.py  18
1 file changed, 17 insertions(+), 1 deletion(-)
diff --git a/examples/llava/llava-surgery.py b/examples/llava/llava-surgery.py
index 26294d9b..515f6b58 100644
--- a/examples/llava/llava-surgery.py
+++ b/examples/llava/llava-surgery.py
@@ -16,13 +16,29 @@ checkpoint = torch.load(path)
mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_projector")]
# store these tensors in a new dictionary and torch.save them
-projector = {name: checkpoint[name] for name in mm_tensors}
+projector = {name: checkpoint[name].float() for name in mm_tensors}
torch.save(projector, f"{args.model}/llava.projector")
# remove these tensors from the checkpoint and save it again
for name in mm_tensors:
del checkpoint[name]
+# BakLLaVA models contain CLIP tensors in the same checkpoint
+clip_tensors = [k for k, v in checkpoint.items() if k.startswith("model.vision_tower")]
+if len(clip_tensors) > 0:
+ clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors}
+ torch.save(clip, f"{args.model}/llava.clip")
+
+ # remove these tensors
+ for name in clip_tensors:
+ del checkpoint[name]
+
+ # added tokens should be removed to be able to convert Mistral models
+ if os.path.exists(f"{args.model}/added_tokens.json"):
+ with open(f"{args.model}/added_tokens.json", "w") as f:
+ f.write("{}\n")
+
+
torch.save(checkpoint, path)
print("Done!")