summaryrefslogtreecommitdiff
path: root/examples/embd-input/minigpt4.py
blob: 7b13e4a5cc4f8e4c61a5d48fcdffb3871c487fc2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
# Demo script: encodes an image with MiniGPT-4's vision stack and feeds the
# resulting embeddings, together with text prompts, to MyModel from embd_input.
import sys
import os
# Put this script's directory first on sys.path before importing embd_input,
# so the import can resolve from there regardless of the working directory.
sys.path.insert(0, os.path.dirname(__file__))
from embd_input import MyModel
import numpy as np
from torch import nn
import torch
from PIL import Image

# The `minigpt4` package is imported from a "MiniGPT-4" directory located next
# to this script; that directory must be on sys.path before the imports below.
minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4")
sys.path.insert(0, minigpt4_path)
from minigpt4.models.blip2 import Blip2Base
from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor


class MiniGPT4(Blip2Base):
    """
    MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4

    Encodes an image with the vision encoder and the Q-Former, projects the
    result with ``llama_proj`` and feeds it, together with text prompts, to
    the ``MyModel`` wrapper imported from ``embd_input``.
    """
    def __init__(self,
        args,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp32",
        freeze_vit=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,
        llama_hidden_size=5120
    ):
        """
        Build the vision encoder, Q-Former and projection layer, then start
        the language model wrapper and feed it the system prompt.

        Args:
            args: Command-line style arguments; passed to ``MyModel`` as
                ``["main", *args]``.
            vit_model, img_size, drop_path_rate, use_grad_checkpoint,
            vit_precision: Forwarded to ``init_vision_encoder``.
            q_former_model: URL or filename given to ``load_from_pretrained``.
            num_query_token: Forwarded to ``init_Qformer``.
            max_txt_len, end_sym: Stored on the instance.
            low_resource: When true, ``encode_img`` calls ``vit_to_cpu`` and
                runs the vision encoder on a CPU copy of the image.
            freeze_vit, freeze_qformer, llama_model, prompt_path,
            prompt_template, device_8bit: Accepted for signature
                compatibility; not read by this class.
            llama_hidden_size: Output width of ``llama_proj``. The default,
                5120, is the value that used to be hard-coded here; it has to
                match the shape of the weights given to ``load_projection``.
        """
        super().__init__()
        self.img_size = img_size
        self.low_resource = low_resource
        self.preprocessor = Blip2ImageEvalProcessor(img_size)

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        print('Loading VIT Done')
        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        # Drop the Q-Former sub-modules this script never calls before the
        # pretrained weights are loaded.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)
        print('Loading Q-Former Done')
        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, llama_hidden_size
        )
        # nn.Module instances start in training mode. This class is only used
        # for inference, so switch every sub-module to eval mode; otherwise
        # any dropout layers would stay active while encoding images
        # (torch.no_grad() does not turn them off).
        self.eval()
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        self.model = MyModel(["main", *args])
        # system prompt
        self.model.eval_string("Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions."
           "###")

    def encode_img(self, image):
        """
        Turn one image into the embeddings produced by ``llama_proj``.

        Args:
            image: Input accepted by ``Blip2ImageEvalProcessor`` (the demo in
                this file passes an RGB ``PIL.Image``).

        Returns:
            ``llama_proj`` applied to the Q-Former's last hidden state, with a
            leading batch dimension of 1 (presumably
            ``(1, num_query_token, llama_hidden_size)`` -- TODO confirm).
        """
        image = self.preprocessor(image)
        image = image.unsqueeze(0)
        # Take the working device from the module's own weights instead of
        # from the freshly preprocessed image, so the inputs follow the model
        # if it has been moved off the CPU.
        device = self.query_tokens.device
        if self.low_resource:
            # NOTE(review): vit_to_cpu is not defined in this class; it is
            # presumably inherited from Blip2Base -- confirm.
            self.vit_to_cpu()
            image = image.to("cpu")
        else:
            image = image.to(device)

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            # Attend to every position of the image features.
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            inputs_llama = self.llama_proj(query_output.last_hidden_state)
        return inputs_llama

    def load_projection(self, path):
        """
        Load the ``llama_proj`` weights from a checkpoint file.

        Args:
            path: Checkpoint whose ``"model"`` entry contains
                ``"llama_proj.weight"`` and ``"llama_proj.bias"``.
        """
        # map_location="cpu" lets a checkpoint saved from CUDA tensors load on
        # a machine without CUDA; load_state_dict then copies the values into
        # the existing parameters on whatever device they already live on.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state = torch.load(path, map_location="cpu")["model"]
        self.llama_proj.load_state_dict({
            "weight": state["llama_proj.weight"],
            "bias": state["llama_proj.bias"]})

    def _ask(self, question):
        """Feed the question and the assistant tag, then generate until "###"."""
        self.model.eval_string(question)
        self.model.eval_string("\n### Assistant:")
        return self.model.generate_with_print(end="###")

    def chat(self, question):
        """
        Ask a text-only question in the running conversation.

        Returns:
            Whatever ``MyModel.generate_with_print`` returns.
        """
        self.model.eval_string("Human: ")
        return self._ask(question)

    def chat_with_image(self, image, question):
        """
        Ask a question about ``image``.

        The image embeddings are placed between ``<Img>`` and ``</Img>`` in the
        prompt, followed by the question.

        Returns:
            Whatever ``MyModel.generate_with_print`` returns.
        """
        with torch.no_grad():
            embd_image = self.encode_img(image)
        # Drop the batch dimension; eval_float is given the transposed array.
        embd_image = embd_image.cpu().numpy()[0]
        self.model.eval_string("Human: <Img>")
        self.model.eval_float(embd_image.T)
        self.model.eval_string("</Img> ")
        return self._ask(question)


if __name__=="__main__":
    # Demo: load the language model and the projection weights, ask one
    # question about an image, then a text-only follow-up in the same
    # conversation.
    minigpt4 = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"])
    # The projection checkpoint is expected next to this script.
    minigpt4.load_projection(os.path.join(
        os.path.dirname(__file__),
        "pretrained_minigpt4.pth"))
    # The return values are not needed here, so they are not bound to a name
    # (the previous binding was misspelled and never read).
    minigpt4.chat_with_image(
        Image.open("./media/llama1-logo.png").convert('RGB'),
        "what is the text in the picture?")
    minigpt4.chat("what is the color of it?")