# minigpt4.py
  1. import sys
  2. import os
  3. sys.path.insert(0, os.path.dirname(__file__))
  4. from embd_input import MyModel
  5. import numpy as np
  6. from torch import nn
  7. import torch
  8. from PIL import Image
  9. minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4")
  10. sys.path.insert(0, minigpt4_path)
  11. from minigpt4.models.blip2 import Blip2Base
  12. from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
class MiniGPT4(Blip2Base):
    """
    MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4

    Bridges the MiniGPT-4 vision pipeline (ViT encoder -> Q-Former -> linear
    projection) to a llama.cpp backend driven through ``MyModel``, so image
    embeddings can be spliced into the LLM's prompt stream as float inputs.
    """
    def __init__(self,
        args,  # extra command-line args forwarded to the llama.cpp "main" backend
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp32",
        freeze_vit=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False, # use 8 bit and put vit in cpu
        device_8bit=0
    ):
        super().__init__()
        self.img_size = img_size
        self.low_resource = low_resource
        self.preprocessor = Blip2ImageEvalProcessor(img_size)

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        print('Loading VIT Done')

        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        # The Q-Former is used here only as a visual feature extractor, so its
        # text-side modules are dropped before loading the pretrained weights.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)
        print('Loading Q-Former Done')

        # Projects Q-Former hidden states into the LLaMA embedding space.
        # 5120 is hard-coded — presumably the 13B model's hidden size, since
        # no LLaMA model object is loaded here; TODO confirm against backend.
        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, 5120 # self.llama_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        # Launch the llama.cpp backend; `args` extend its argv.
        self.model = MyModel(["main", *args])
        # system prompt
        self.model.eval_string("Give the following image: <Img>ImageContent</Img>. "
            "You will be able to see the image once I provide it to you. Please answer my questions."
            "###")

    def encode_img(self, image):
        """Encode one PIL image into LLaMA-space embeddings.

        Returns the llama_proj output: one projected embedding per Q-Former
        query token, with a leading batch dimension of 1 (added by
        ``unsqueeze(0)`` below).
        """
        image = self.preprocessor(image)
        image = image.unsqueeze(0)  # add batch dimension
        device = image.device
        if self.low_resource:
            # Run the (large) ViT on CPU to save accelerator memory.
            self.vit_to_cpu()
            image = image.to("cpu")
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            inputs_llama = self.llama_proj(query_output.last_hidden_state)
            # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama

    def load_projection(self, path):
        """Load pretrained ``llama_proj`` weights from a MiniGPT-4 checkpoint.

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
        trusted checkpoint files.
        """
        state = torch.load(path)["model"]
        self.llama_proj.load_state_dict({
            "weight": state["llama_proj.weight"],
            "bias": state["llama_proj.bias"]})

    def chat(self, question):
        """Send a text-only follow-up question and stream the reply.

        Relies on conversation state already accumulated in the backend.
        """
        self.model.eval_string("Human: ")
        self.model.eval_string(question)
        self.model.eval_string("\n### Assistant:")
        return self.model.generate_with_print(end="###")

    def chat_with_image(self, image, question):
        """Ask a question about *image*, injecting its embeddings into the prompt."""
        with torch.no_grad():
            embd_image = self.encode_img(image)
        embd_image = embd_image.cpu().numpy()[0]  # drop batch dim -> (n_query, hidden)
        self.model.eval_string("Human: <Img>")
        # Transposed before hand-off — presumably embd_input expects
        # feature-major (hidden, n_query) layout; TODO confirm.
        self.model.eval_float(embd_image.T)
        self.model.eval_string("</Img> ")
        self.model.eval_string(question)
        self.model.eval_string("\n### Assistant:")
        return self.model.generate_with_print(end="###")
  107. if __name__=="__main__":
  108. a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"])
  109. a.load_projection(os.path.join(
  110. os.path.dirname(__file__) ,
  111. "pretrained_minigpt4.pth"))
  112. respose = a.chat_with_image(
  113. Image.open("./media/llama1-logo.png").convert('RGB'),
  114. "what is the text in the picture?")
  115. a.chat("what is the color of it?")