author    ningshanwutuobang <ningshanwutuobang@gmail.com>  2023-06-28 23:53:37 +0800
committer GitHub <noreply@github.com>  2023-06-28 18:53:37 +0300
commit    cfa0750bc9dbc2d957a91b8ed09ab0035d8f3d4e (patch)
tree      c8d6d6e6548d4f03899704f64bce6939e471e4e6 /examples/embd-input/llava.py
parent    9d23589d638dc74577d5ff880e6d4248b795f12e (diff)
llama : support input embeddings directly (#1910)
* add interface for float input
* fixed inpL shape and type
* add examples of input floats
* add test example for embd input
* fixed sampling
* add free for context
* fixed end condition for generating
* add examples for llava.py
* add README for llava.py
* add example of PandaGPT
* refactor the interface and fix the styles
* add cmake build for embd-input
* Add MiniGPT-4 example
* change the order of the args of llama_eval_internal
* fix ci error
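The Python examples drive this new float-input path through the MyModel wrapper in examples/embd-input/embd_input.py. The sketch below is only an illustration of the call flow; it uses nothing beyond the wrapper methods that llava.py itself calls in the diff further down (eval_string, eval_float, generate_with_print), and the zero array is a placeholder for real projected image features.

    import numpy as np
    from embd_input import MyModel

    # placeholder for projected image features; llava.py below produces a
    # (n_patches, n_embd) array from the CLIP vision tower + mm_projector
    image_features = np.zeros((256, 5120), dtype=np.float32)

    model = MyModel(["main", "--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
    model.eval_string("user: ")          # plain text goes through the tokenizer as usual
    model.eval_float(image_features.T)   # float embeddings are fed to the model directly
    model.eval_string("what is in the picture?\nassistant: ")
    model.generate_with_print()          # streams the reply to stdout (and returns it)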
Diffstat (limited to 'examples/embd-input/llava.py')
-rw-r--r--  examples/embd-input/llava.py  70
1 file changed, 70 insertions(+), 0 deletions(-)
diff --git a/examples/embd-input/llava.py b/examples/embd-input/llava.py
new file mode 100644
index 00000000..2f20cb72
--- /dev/null
+++ b/examples/embd-input/llava.py
@@ -0,0 +1,70 @@
+import sys
+import os
+sys.path.insert(0, os.path.dirname(__file__))
+from embd_input import MyModel
+import numpy as np
+from torch import nn
+import torch
+from transformers import CLIPVisionModel, CLIPImageProcessor
+from PIL import Image
+
+# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1'
+vision_tower = "openai/clip-vit-large-patch14"
+select_hidden_state_layer = -2  # take image features from the second-to-last hidden layer of the CLIP vision tower
+# (vision_config.image_size // vision_config.patch_size) ** 2
+image_token_len = (224//14)**2
+
+class Llava:
+ def __init__(self, args):
+ self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
+ self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
+        self.mm_projector = nn.Linear(1024, 5120)  # project CLIP hidden size (1024) to LLaMA-13B embedding size (5120)
+ self.model = MyModel(["main", *args])
+
+ def load_projection(self, path):
+ state = torch.load(path)
+ self.mm_projector.load_state_dict({
+ "weight": state["model.mm_projector.weight"],
+ "bias": state["model.mm_projector.bias"]})
+
+ def chat(self, question):
+ self.model.eval_string("user: ")
+ self.model.eval_string(question)
+ self.model.eval_string("\nassistant: ")
+ return self.model.generate_with_print()
+
+ def chat_with_image(self, image, question):
+ with torch.no_grad():
+ embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
+ select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
+            image_feature = select_hidden_state[:, 1:]  # drop the leading class token, keep the patch tokens
+ embd_image = self.mm_projector(image_feature)
+ embd_image = embd_image.cpu().numpy()[0]
+ self.model.eval_string("user: ")
+ self.model.eval_token(32003-2) # im_start
+        self.model.eval_float(embd_image.T)  # feed the projected image features to the model as raw embeddings
+ for i in range(image_token_len-embd_image.shape[0]):
+ self.model.eval_token(32003-3) # im_patch
+ self.model.eval_token(32003-1) # im_end
+ self.model.eval_string(question)
+ self.model.eval_string("\nassistant: ")
+ return self.model.generate_with_print()
+
+
+if __name__ == "__main__":
+    # model from liuhaotian/LLaVA-13b-delta-v1-1
+    a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
+    # Extract the projector weights from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin
+    # (see the extraction sketch after this diff); the full pytorch_model-00003-of-00003.bin can also be used directly.
+    a.load_projection(os.path.join(
+        os.path.dirname(__file__),
+        "llava_projection.pth"))
+    response = a.chat_with_image(
+ Image.open("./media/llama1-logo.png").convert('RGB'),
+ "what is the text in the picture?")
+    # 'response' holds the generated text; generate_with_print() has already echoed it
+ a.chat("what is the color of it?")
+
+
+
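For reference, the llava_projection.pth file loaded above only needs the two mm_projector tensors that load_projection() reads. A rough, untested sketch of how it could be produced from the checkpoint shard named in the comment (paths are illustrative):

    import torch

    # load the LLaVA delta shard that contains the multimodal projector weights
    state = torch.load("pytorch_model-00003-of-00003.bin", map_location="cpu")

    # keep only model.mm_projector.weight / model.mm_projector.bias and save them on their own
    torch.save({k: v for k, v in state.items() if k.startswith("model.mm_projector.")},
               "llava_projection.pth")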