path: root/examples/llava/llava_surgery_v2.py
author    Kawrakow <48489457+ikawrakow@users.noreply.github.com>  2024-07-27 07:55:01 +0200
committer GitHub <noreply@github.com>  2024-07-27 07:55:01 +0200
commit    154e0d75fccf1784fe9ff6fd76a630b66563da3d (patch)
tree      81ce6dbb5b1900c1aa78a879f0593c694cab9d27 /examples/llava/llava_surgery_v2.py
parent    0684c3e9c70d49323b4fc517128cbe222cab7f96 (diff)
Merge mainline llama.cpp (#3)
* Merging mainline - WIP
* Merging mainline - WIP
  AVX2 and CUDA appear to work. CUDA performance seems slightly (~1-2%) lower as it is so often the case with llama.cpp/ggml after some "improvements" have been made.
* Merging mainline - fix Metal
* Remove check
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'examples/llava/llava_surgery_v2.py')
-rw-r--r--  examples/llava/llava_surgery_v2.py  159
1 file changed, 159 insertions, 0 deletions
diff --git a/examples/llava/llava_surgery_v2.py b/examples/llava/llava_surgery_v2.py
new file mode 100644
index 00000000..2d5b32fe
--- /dev/null
+++ b/examples/llava/llava_surgery_v2.py
@@ -0,0 +1,159 @@
+import argparse
+import glob
+import os
+import torch
+from safetensors import safe_open
+from safetensors.torch import save_file
+from typing import Any, ContextManager, cast
+
+# Function to determine if file is a SafeTensor file
+def is_safetensor_file(file_path):
+    return file_path.endswith('.safetensors')
+
+
+# Unified loading function
+def load_model(file_path):
+    if is_safetensor_file(file_path):
+        tensors = {}
+        with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
+            for key in f.keys():
+                tensors[key] = f.get_tensor(key).clone()
+                # output shape
+                print(f"{key} : {tensors[key].shape}")
+        return tensors, 'safetensor'
+    else:
+        return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch'
+
+
+# Unified saving function
+def save_model(model, file_path, file_type):
+    if file_type == 'safetensor':
+        # safe_save(model, file_path)
+        save_file(model, file_path)
+    else:
+        torch.save(model, file_path)
+
+
+# Adapted function to clean vision tower from checkpoint
+def clean_vision_tower_from_checkpoint(checkpoint_path):
+    checkpoint, file_type = load_model(checkpoint_path)
+    # file_type = 'pytorch'
+    model_path = os.path.dirname(checkpoint_path)
+    print(f"Searching for vision tower tensors in {checkpoint_path}")
+    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
+
+    if len(clip_tensors) > 0:
+        print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
+        # Adapted for file type
+        clip_path = os.path.join(model_path, "llava.clip")
+
+        if os.path.exists(clip_path):
+            print(f"Loading existing llava.clip from {clip_path}")
+            existing_clip, _ = load_model(clip_path)
+        else:
+            print(f"Creating new llava.clip at {clip_path}")
+            existing_clip = {}
+        # Update existing_clip with new tensors, avoid duplicates
+        for name in clip_tensors:
+            simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
+            print(f"Adding {simple_name} to llava.clip")
+            if simple_name not in existing_clip:
+                existing_clip[simple_name] = checkpoint[name]
+
+        # Save the updated clip tensors back to llava.clip
+        save_model(existing_clip, clip_path, 'pytorch')
+
+        # Remove the tensors from the original checkpoint
+        for name in clip_tensors:
+            del checkpoint[name]
+
+        checkpoint_path = checkpoint_path
+        return True
+    return False
+
+def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
+    newline_checkpoint_path = None
+    projector_checkpoint_path = None
+
+    for path in checkpoint_paths:
+        checkpoint, _ = load_model(path)
+        if newline_criteria(checkpoint) and newline_checkpoint_path is None:
+            newline_checkpoint_path = path
+        if projector(checkpoint):
+            projector_checkpoint_path = path
+
+    return newline_checkpoint_path, projector_checkpoint_path
+
+def newline_criteria(checkpoint):
+    return any(k.startswith("model.image_newline") for k in checkpoint.keys())
+
+def proj_criteria(checkpoint):
+    return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
+
+
+# Command-line interface setup
+ap = argparse.ArgumentParser()
+ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model")
+ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files")
+args = ap.parse_args()
+
+if args.clean_vision_tower:
+    # Generalized to handle both PyTorch and SafeTensors models
+    model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
+    # checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
+    checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
+    for projector_checkpoint_path in checkpoint_paths:
+        print(f"Cleaning {projector_checkpoint_path}")
+        if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
+            print(f"No vision tower found in {projector_checkpoint_path}")
+            # we break once none is found, so far all models append them at the end
+            # break
+    print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")
+
+# Now we look for the projector in the last checkpoint
+model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
+checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
+# last_checkpoint_path = checkpoint_paths[0]
+# first_checkpoint_path = checkpoint_paths[-1]
+newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
+
+print(f"Taking projector from {projector_checkpoint_path}")
+first_mm_tensors = []
+first_checkpoint = None
+if newline_checkpoint_path is not None:
+    print(f"Taking newline from {newline_checkpoint_path}")
+    first_checkpoint, file_type = load_model(newline_checkpoint_path)
+    first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
+
+# Load the checkpoint
+mm_tensors = []
+last_checkpoint = None
+if projector_checkpoint_path is not None:
+    last_checkpoint, file_type = load_model(projector_checkpoint_path)
+    mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
+
+if len(mm_tensors) == 0:
+    if last_checkpoint is not None:
+        for k, v in last_checkpoint.items():
+            print(k)
+    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
+    print("No tensors found. Is this a LLaVA model?")
+    exit()
+
+print(f"Found {len(mm_tensors)} tensors to extract.")
+print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
+# projector = {name: checkpoint.[name].float() for name in mm_tensors}
+projector = {}
+for name in mm_tensors:
+    assert last_checkpoint is not None
+    projector[name] = last_checkpoint[name].float()
+for name in first_mm_tensors:
+    assert first_checkpoint is not None
+    projector[name] = first_checkpoint[name].float()
+
+if len(projector) > 0:
+    save_model(projector, f"{args.model}/llava.projector", 'pytorch')
+
+print("Done!")
+print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
+print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")