Diffstat (limited to 'examples/embd-input/minigpt4.py')
-rwxr-xr-x  examples/embd-input/minigpt4.py | 129
 1 file changed, 0 insertions(+), 129 deletions(-)
diff --git a/examples/embd-input/minigpt4.py b/examples/embd-input/minigpt4.py
deleted file mode 100755
index 7b13e4a5..00000000
--- a/examples/embd-input/minigpt4.py
+++ /dev/null
@@ -1,129 +0,0 @@
-#!/usr/bin/env python3
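-# MiniGPT-4 front-end for llama.cpp's embd-input example: an image is encoded with
-# the BLIP-2 ViT + Q-Former stack, projected into LLaMA's embedding space, and the
-# resulting float embeddings are streamed to the language model for multimodal chat.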
-import sys
-import os
-sys.path.insert(0, os.path.dirname(__file__))
-from embd_input import MyModel
-import numpy as np
-from torch import nn
-import torch
-from PIL import Image
-
-minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4")
-sys.path.insert(0, minigpt4_path)
-from minigpt4.models.blip2 import Blip2Base
-from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
-
-
-class MiniGPT4(Blip2Base):
-    """
-    MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4
-    """
-    def __init__(self,
-        args,
-        vit_model="eva_clip_g",
-        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
-        img_size=224,
-        drop_path_rate=0,
-        use_grad_checkpoint=False,
-        vit_precision="fp32",
-        freeze_vit=True,
-        freeze_qformer=True,
-        num_query_token=32,
-        llama_model="",
-        prompt_path="",
-        prompt_template="",
-        max_txt_len=32,
-        end_sym='\n',
-        low_resource=False,  # use 8-bit and run the ViT on the CPU
-        device_8bit=0
-        ):
-        super().__init__()
-        self.img_size = img_size
-        self.low_resource = low_resource
-        self.preprocessor = Blip2ImageEvalProcessor(img_size)
-
-        print('Loading VIT')
-        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
-            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
-        )
-        print('Loading VIT Done')
-        print('Loading Q-Former')
-        self.Qformer, self.query_tokens = self.init_Qformer(
-            num_query_token, self.visual_encoder.num_features
-        )
-        self.Qformer.cls = None
-        self.Qformer.bert.embeddings.word_embeddings = None
-        self.Qformer.bert.embeddings.position_embeddings = None
-        for layer in self.Qformer.bert.encoder.layer:
-            layer.output = None
-            layer.intermediate = None
-        self.load_from_pretrained(url_or_filename=q_former_model)
-        print('Loading Q-Former Done')
-        self.llama_proj = nn.Linear(
-            self.Qformer.config.hidden_size, 5120  # = self.llama_model.config.hidden_size for a 13B model
-        )
-        self.max_txt_len = max_txt_len
-        self.end_sym = end_sym
-        self.model = MyModel(["main", *args])
-        # system prompt
-        self.model.eval_string("Give the following image: <Img>ImageContent</Img>. "
-            "You will be able to see the image once I provide it to you. Please answer my questions."
-            "###")
-
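-    # Turn a PIL image into LLaMA-ready embeddings: preprocess, run the ViT,
-    # pass the features through the Q-Former, then project with llama_proj.
-    # Returns a tensor of shape (1, num_query_token, 5120).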
-    def encode_img(self, image):
-        image = self.preprocessor(image)
-        image = image.unsqueeze(0)
-        device = image.device
-        if self.low_resource:
-            self.vit_to_cpu()
-            image = image.to("cpu")
-
-        with self.maybe_autocast():
-            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
-            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
-
-            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-            query_output = self.Qformer.bert(
-                query_embeds=query_tokens,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=True,
-            )
-
-            inputs_llama = self.llama_proj(query_output.last_hidden_state)
-            # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
-        return inputs_llama
-
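-    # Load only the llama_proj linear-layer weights from a MiniGPT-4 checkpoint
-    # (e.g. pretrained_minigpt4.pth); the ViT and Q-Former come from BLIP-2 weights.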
-    def load_projection(self, path):
-        state = torch.load(path)["model"]
-        self.llama_proj.load_state_dict({
-            "weight": state["llama_proj.weight"],
-            "bias": state["llama_proj.bias"]})
-
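-    # Text-only turn: append the question in the Human/Assistant chat format and
-    # generate until the "###" separator is produced.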
-    def chat(self, question):
-        self.model.eval_string("Human: ")
-        self.model.eval_string(question)
-        self.model.eval_string("\n### Assistant:")
-        return self.model.generate_with_print(end="###")
-
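-    # Image turn: encode the image and hand the projected embeddings to the model
-    # via eval_float (note the transpose), wrapped in <Img>...</Img> markers.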
-    def chat_with_image(self, image, question):
-        with torch.no_grad():
-            embd_image = self.encode_img(image)
-        embd_image = embd_image.cpu().numpy()[0]
-        self.model.eval_string("Human: <Img>")
-        self.model.eval_float(embd_image.T)
-        self.model.eval_string("</Img> ")
-        self.model.eval_string(question)
-        self.model.eval_string("\n### Assistant:")
-        return self.model.generate_with_print(end="###")
-
-
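-# Example: load a quantized Vicuna-13B model plus the MiniGPT-4 projection weights,
-# ask about an image, then follow up with a text-only question.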
-if __name__ == "__main__":
-    a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"])
-    a.load_projection(os.path.join(
-        os.path.dirname(__file__),
-        "pretrained_minigpt4.pth"))
-    response = a.chat_with_image(
-        Image.open("./media/llama1-logo.png").convert('RGB'),
-        "what is the text in the picture?")
-    a.chat("what is the color of it?")