| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- #!/usr/bin/env python3
- import sys
- import os
- sys.path.insert(0, os.path.dirname(__file__))
- from embd_input import MyModel
- import numpy as np
- from torch import nn
- import torch
- from PIL import Image
- minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4")
- sys.path.insert(0, minigpt4_path)
- from minigpt4.models.blip2 import Blip2Base
- from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
class MiniGPT4(Blip2Base):
    """
    MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4

    Inference-only wrapper. The vision side (ViT encoder, Q-Former and a
    linear projection) runs in PyTorch. The projected image embeddings are
    then fed, together with text, into the language model driven through
    ``MyModel`` (from ``embd_input``).
    """

    def __init__(self,
                 args,
                 vit_model="eva_clip_g",
                 q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
                 img_size=224,
                 drop_path_rate=0,
                 use_grad_checkpoint=False,
                 vit_precision="fp32",
                 freeze_vit=True,
                 freeze_qformer=True,
                 num_query_token=32,
                 llama_model="",
                 prompt_path="",
                 prompt_template="",
                 max_txt_len=32,
                 end_sym='\n',
                 low_resource=False,  # upstream: 8-bit LLM + ViT on CPU; here only the ViT part applies
                 device_8bit=0,
                 llama_hidden_size=5120,
                 ):
        """Build the vision encoder, Q-Former and projection, then start the LLM.

        Args:
            args: extra command-line style arguments forwarded to ``MyModel``
                (a dummy ``"main"`` program name is prepended).
            vit_model, img_size, drop_path_rate, use_grad_checkpoint,
            vit_precision: forwarded to ``init_vision_encoder``.
            q_former_model: URL or path of the Q-Former checkpoint passed to
                ``load_from_pretrained``.
            num_query_token: number of learned query tokens for the Q-Former.
            max_txt_len, end_sym: stored on the instance; not otherwise used
                by this class.
            low_resource: when true, ``encode_img`` runs the ViT on the CPU.
            freeze_vit, freeze_qformer, llama_model, prompt_path,
            prompt_template, device_8bit: accepted for signature
                compatibility but unused by this wrapper.
            llama_hidden_size: output width of ``llama_proj``; must match the
                embedding width expected by the model loaded in ``MyModel``
                and by the projection checkpoint. Defaults to 5120, the value
                this wrapper previously hard-coded.
        """
        super().__init__()
        self.img_size = img_size
        self.low_resource = low_resource
        self.preprocessor = Blip2ImageEvalProcessor(img_size)

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        print('Loading VIT Done')

        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        # encode_img only feeds query embeddings (no text) into the Q-Former,
        # so these modules are discarded, mirroring upstream MiniGPT-4.
        # NOTE(review): presumably they are text-path only; confirm against
        # the Q-Former implementation before reusing them.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)
        print('Loading Q-Former Done')

        # Maps Q-Former outputs into the language model's embedding space.
        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, llama_hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        # This wrapper is inference-only: make sure no submodule is left in
        # training mode (e.g. with dropout active).
        self.eval()

        self.model = MyModel(["main", *args])
        # system prompt
        self.model.eval_string("Give the following image: <Img>ImageContent</Img>. "
                               "You will be able to see the image once I provide it to you. Please answer my questions."
                               "###")

    def vit_to_cpu(self):
        """Move the vision encoder and its LayerNorm to the CPU in float32.

        ``encode_img`` calls this when ``low_resource`` is set. It is defined
        here because this class relies on it. Upstream MiniGPT-4 defines the
        same helper on its own model class, and ``Blip2Base`` may not
        provide it.
        """
        # float32 because half-precision kernels are poorly supported on CPU.
        self.ln_vision.to("cpu")
        self.ln_vision.float()
        self.visual_encoder.to("cpu")
        self.visual_encoder.float()

    def encode_img(self, image):
        """Encode one image into the language model's embedding space.

        Args:
            image: a single image accepted by ``Blip2ImageEvalProcessor``
                (``chat_with_image`` is fed an RGB PIL image).

        Returns:
            ``llama_proj`` applied to the Q-Former's last hidden state, with a
            leading batch dimension of 1. The expected shape is
            ``(1, num_query_token, llama_hidden_size)``; confirm this if you
            depend on it.
        """
        image = self.preprocessor(image).unsqueeze(0)
        # Run the Q-Former and projection wherever the model parameters live.
        # Previously the device was taken from the freshly preprocessed
        # (CPU) image, which broke after moving the wrapper with .to(device).
        device = self.query_tokens.device
        image = image.to(device)
        if self.low_resource:
            self.vit_to_cpu()
            image = image.to("cpu")

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            # Every image patch is attended to.
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            inputs_llama = self.llama_proj(query_output.last_hidden_state)
        return inputs_llama

    def load_projection(self, path):
        """Load the ``llama_proj`` weights from a MiniGPT-4 checkpoint.

        Args:
            path: checkpoint file whose ``"model"`` entry holds
                ``llama_proj.weight`` and ``llama_proj.bias``.
        """
        # map_location="cpu": the checkpoint may contain CUDA tensors.
        # Loading onto the CPU works on CPU-only hosts, and load_state_dict
        # then copies the values into llama_proj wherever it lives.
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted checkpoints.
        state = torch.load(path, map_location="cpu")["model"]
        self.llama_proj.load_state_dict({
            "weight": state["llama_proj.weight"],
            "bias": state["llama_proj.bias"],
        })

    def _finish_turn(self, question):
        """Send ``question``, open the assistant turn and return the generation."""
        self.model.eval_string(question)
        self.model.eval_string("\n### Assistant:")
        # "###" is the turn separator used in the prompt; presumably
        # generation stops when it is produced.
        return self.model.generate_with_print(end="###")

    def chat(self, question):
        """Ask a text-only question; return the result of ``generate_with_print``."""
        self.model.eval_string("Human: ")
        return self._finish_turn(question)

    def chat_with_image(self, image, question):
        """Ask ``question`` about ``image``; return the result of ``generate_with_print``.

        ``image`` must be accepted by the preprocessor (e.g. an RGB PIL image).
        """
        with torch.no_grad():
            embd_image = self.encode_img(image)
        # Drop the batch dimension. .float() is a no-op for float32 and
        # guarantees float32 if encode_img ran under half-precision autocast.
        # NOTE(review): assumes eval_float expects float32; confirm against
        # embd_input.
        embd_image = embd_image.float().cpu().numpy()[0]
        self.model.eval_string("Human: <Img>")
        self.model.eval_float(embd_image.T)
        self.model.eval_string("</Img> ")
        return self._finish_turn(question)
if __name__ == "__main__":
    # Demo: start the model from a GGML Vicuna file (`-c 2048` is passed
    # through to MyModel, presumably the llama.cpp context size), load the
    # MiniGPT-4 projection, then ask one question about an image and one
    # text-only follow-up.
    # The model and image paths are relative to the current working
    # directory; the projection checkpoint is looked up next to this script.
    demo = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"])
    demo.load_projection(os.path.join(
        os.path.dirname(__file__),
        "pretrained_minigpt4.pth"))
    # Open the image in a context manager so the file handle is closed once
    # the RGB copy exists (previously it stayed open until garbage collection).
    with Image.open("./media/llama1-logo.png") as img:
        rgb_image = img.convert('RGB')
    demo.chat_with_image(rgb_image, "what is the text in the picture?")
    demo.chat("what is the color of it?")
|