diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-10-20 21:07:23 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-20 21:07:23 +0300 |
commit | d1031cf49c3b958b915fd558e23453471c29ac33 (patch) | |
tree | 14fa2bc6d54d5e27bd1e8bfd6fa4dbf894dbe6b9 /examples/embd-input/panda_gpt.py | |
parent | 8cf19d60dc93809db8e51fedc811595eed9134c5 (diff) |
sampling : refactor init to use llama_sampling_params (#3696)
* sampling : refactor init to use llama_sampling_params
* llama : combine repetition, frequency and presence penalties in 1 call
* examples : remove embd-input and gptneox-wip
* sampling : rename penalty params + reduce size of "prev" vector
* sampling : add llama_sampling_print helper
* sampling : hide prev behind API and apply #3661
ggml-ci
Diffstat (limited to 'examples/embd-input/panda_gpt.py')
-rwxr-xr-x | examples/embd-input/panda_gpt.py | 99 |
1 files changed, 0 insertions, 99 deletions
diff --git a/examples/embd-input/panda_gpt.py b/examples/embd-input/panda_gpt.py deleted file mode 100755 index 891ad7cc..00000000 --- a/examples/embd-input/panda_gpt.py +++ /dev/null @@ -1,99 +0,0 @@ -#!/usr/bin/env python3 -import sys -import os -sys.path.insert(0, os.path.dirname(__file__)) -from embd_input import MyModel -import numpy as np -from torch import nn -import torch - -# use PandaGPT path -panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT") -imagebind_ckpt_path = "./models/panda_gpt/" - -sys.path.insert(0, os.path.join(panda_gpt_path,"code","model")) -from ImageBind.models import imagebind_model -from ImageBind import data - -ModalityType = imagebind_model.ModalityType -max_tgt_len = 400 - -class PandaGPT: - def __init__(self, args): - self.visual_encoder,_ = imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path) - self.visual_encoder.eval() - self.llama_proj = nn.Linear(1024, 5120) # self.visual_hidden_size, 5120) - self.max_tgt_len = max_tgt_len - self.model = MyModel(["main", *args]) - self.generated_text = "" - self.device = "cpu" - - def load_projection(self, path): - state = torch.load(path, map_location="cpu") - self.llama_proj.load_state_dict({ - "weight": state["llama_proj.weight"], - "bias": state["llama_proj.bias"]}) - - def eval_inputs(self, inputs): - self.model.eval_string("<Img>") - embds = self.extract_multimoal_feature(inputs) - for i in embds: - self.model.eval_float(i.T) - self.model.eval_string("</Img> ") - - def chat(self, question): - return self.chat_with_image(None, question) - - def chat_with_image(self, inputs, question): - if self.generated_text == "": - self.model.eval_string("###") - self.model.eval_string(" Human: ") - if inputs: - self.eval_inputs(inputs) - self.model.eval_string(question) - self.model.eval_string("\n### Assistant:") - ret = self.model.generate_with_print(end="###") - self.generated_text += ret - return ret - - def extract_multimoal_feature(self, inputs): - features = [] - for key in ["image", "audio", "video", "thermal"]: - if key + "_paths" in inputs: - embeds = self.encode_data(key, inputs[key+"_paths"]) - features.append(embeds) - return features - - def encode_data(self, data_type, data_paths): - - type_map = { - "image": ModalityType.VISION, - "audio": ModalityType.AUDIO, - "video": ModalityType.VISION, - "thermal": ModalityType.THERMAL, - } - load_map = { - "image": data.load_and_transform_vision_data, - "audio": data.load_and_transform_audio_data, - "video": data.load_and_transform_video_data, - "thermal": data.load_and_transform_thermal_data - } - - load_function = load_map[data_type] - key = type_map[data_type] - - inputs = {key: load_function(data_paths, self.device)} - with torch.no_grad(): - embeddings = self.visual_encoder(inputs) - embeds = embeddings[key] - embeds = self.llama_proj(embeds).cpu().numpy() - return embeds - - -if __name__=="__main__": - a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048", "--lora", "./models/panda_gpt/ggml-adapter-model.bin","--temp", "0"]) - a.load_projection("./models/panda_gpt/adapter_model.bin") - a.chat_with_image( - {"image_paths": ["./media/llama1-logo.png"]}, - "what is the text in the picture? 'llama' or 'lambda'?") - a.chat("what is the color of it?") |