summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorM. Yusuf Sarıgöz <yusufsarigoz@gmail.com>2023-10-19 19:40:41 +0300
committerGitHub <noreply@github.com>2023-10-19 19:40:41 +0300
commitf3b25e40438b3c8383caabf4e7b89863145a9f0e (patch)
tree19ebd9eb2ff9389f23cfb8eb2bfd1214f8695923 /examples
parent60abea9798f47b918a9f38c66edfd88c526d20f6 (diff)
multimodal : add BakLLaVA conversion support (#3682)
Diffstat (limited to 'examples')
-rw-r--r--examples/llava/llava-surgery.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/examples/llava/llava-surgery.py b/examples/llava/llava-surgery.py
index 26294d9b..515f6b58 100644
--- a/examples/llava/llava-surgery.py
+++ b/examples/llava/llava-surgery.py
@@ -16,13 +16,29 @@ checkpoint = torch.load(path)
mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_projector")]
# store these tensors in a new dictionary and torch.save them
-projector = {name: checkpoint[name] for name in mm_tensors}
+projector = {name: checkpoint[name].float() for name in mm_tensors}
torch.save(projector, f"{args.model}/llava.projector")
# remove these tensors from the checkpoint and save it again
for name in mm_tensors:
del checkpoint[name]
+# BakLLaVA models contain CLIP tensors in them
+clip_tensors = [k for k, v in checkpoint.items() if k.startswith("model.vision_tower")]
+if len(clip_tensors) > 0:
+ clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors}
+ torch.save(clip, f"{args.model}/llava.clip")
+
+ # remove these tensors
+ for name in clip_tensors:
+ del checkpoint[name]
+
+ # added tokens should be removed to be able to convert Mistral models
+ if os.path.exists(f"{args.model}/added_tokens.json"):
+ with open(f"{args.model}/added_tokens.json", "w") as f:
+ f.write("{}\n")
+
+
torch.save(checkpoint, path)
print("Done!")