diff options
author | ningshanwutuobang <ningshanwutuobang@gmail.com> | 2023-06-28 23:53:37 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-28 18:53:37 +0300 |
commit | cfa0750bc9dbc2d957a91b8ed09ab0035d8f3d4e (patch) | |
tree | c8d6d6e6548d4f03899704f64bce6939e471e4e6 /examples/embd-input/minigpt4.py | |
parent | 9d23589d638dc74577d5ff880e6d4248b795f12e (diff) |
llama : support input embeddings directly (#1910)
* add interface for float input
* fixed inpL shape and type
* add examples of input floats
* add test example for embd input
* fixed sampling
* add free for context
* fixed add end condition for generating
* add examples for llava.py
* add READMD for llava.py
* add READMD for llava.py
* add example of PandaGPT
* refactor the interface and fixed the styles
* add cmake build for embd-input
* add cmake build for embd-input
* Add MiniGPT-4 example
* change the order of the args of llama_eval_internal
* fix ci error
Diffstat (limited to 'examples/embd-input/minigpt4.py')
-rw-r--r-- | examples/embd-input/minigpt4.py | 128 |
1 files changed, 128 insertions, 0 deletions
diff --git a/examples/embd-input/minigpt4.py b/examples/embd-input/minigpt4.py new file mode 100644 index 00000000..8e98f851 --- /dev/null +++ b/examples/embd-input/minigpt4.py @@ -0,0 +1,128 @@ +import sys +import os +sys.path.insert(0, os.path.dirname(__file__)) +from embd_input import MyModel +import numpy as np +from torch import nn +import torch +from PIL import Image + +minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4") +sys.path.insert(0, minigpt4_path) +from minigpt4.models.blip2 import Blip2Base +from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor + + +class MiniGPT4(Blip2Base): + """ + MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4 + """ + def __init__(self, + args, + vit_model="eva_clip_g", + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp32", + freeze_vit=True, + freeze_qformer=True, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0 + ): + super().__init__() + self.img_size = img_size + self.low_resource = low_resource + self.preprocessor = Blip2ImageEvalProcessor(img_size) + + print('Loading VIT') + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + print('Loading VIT Done') + print('Loading Q-Former') + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.load_from_pretrained(url_or_filename=q_former_model) + print('Loading Q-Former Done') + self.llama_proj = nn.Linear( + self.Qformer.config.hidden_size, 5120 # self.llama_model.config.hidden_size + ) + self.max_txt_len = max_txt_len + self.end_sym = end_sym + self.model = MyModel(["main", *args]) + # system promt + self.model.eval_string("Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions." + "###") + + def encode_img(self, image): + image = self.preprocessor(image) + image = image.unsqueeze(0) + device = image.device + if self.low_resource: + self.vit_to_cpu() + image = image.to("cpu") + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_llama = self.llama_proj(query_output.last_hidden_state) + # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama + + def load_projection(self, path): + state = torch.load(path)["model"] + self.llama_proj.load_state_dict({ + "weight": state["llama_proj.weight"], + "bias": state["llama_proj.bias"]}) + + def chat(self, question): + self.model.eval_string("Human: ") + self.model.eval_string(question) + self.model.eval_string("\n### Assistant:") + return self.model.generate_with_print(end="###") + + def chat_with_image(self, image, question): + with torch.no_grad(): + embd_image = self.encode_img(image) + embd_image = embd_image.cpu().numpy()[0] + self.model.eval_string("Human: <Img>") + self.model.eval_float(embd_image.T) + self.model.eval_string("</Img> ") + self.model.eval_string(question) + self.model.eval_string("\n### Assistant:") + return self.model.generate_with_print(end="###") + + +if __name__=="__main__": + a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"]) + a.load_projection(os.path.join( + os.path.dirname(__file__) , + "pretrained_minigpt4.pth")) + respose = a.chat_with_image( + Image.open("./media/llama1-logo.png").convert('RGB'), + "what is the text in the picture?") + a.chat("what is the color of it?") |