minigpt4.py

#!/usr/bin/env python3
import sys
import os

# make the local embd_input bindings (same directory as this script) importable
sys.path.insert(0, os.path.dirname(__file__))
from embd_input import MyModel

import numpy as np
from torch import nn
import torch
from PIL import Image

# the MiniGPT-4 repository is expected to be cloned next to this script
minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4")
sys.path.insert(0, minigpt4_path)
from minigpt4.models.blip2 import Blip2Base
from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
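

# The wrapper below glues MiniGPT-4's vision front-end (EVA ViT + Q-Former) to
# a quantized ggml Vicuna model driven through the embd_input bindings
# (MyModel): images are encoded into embedding vectors and streamed into the
# model context alongside the text prompt.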
class MiniGPT4(Blip2Base):
    """
    MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4
    """
    def __init__(self,
        args,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp32",
        freeze_vit=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False,  # use 8-bit and keep the ViT on the CPU
        device_8bit=0
    ):
        super().__init__()
        self.img_size = img_size
        self.low_resource = low_resource
        self.preprocessor = Blip2ImageEvalProcessor(img_size)

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        print('Loading VIT Done')

        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        # drop the text-side heads of the Q-Former; only the cross-attention
        # over image features is needed here
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)
        print('Loading Q-Former Done')

        # project Q-Former outputs into the LLaMA embedding space
        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, 5120  # self.llama_model.config.hidden_size (LLaMA-13B)
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

        self.model = MyModel(["main", *args])
        # system prompt
        self.model.eval_string("Give the following image: <Img>ImageContent</Img>. "
            "You will be able to see the image once I provide it to you. Please answer my questions."
            "###")
    def encode_img(self, image):
        image = self.preprocessor(image)
        image = image.unsqueeze(0)
        device = image.device
        if self.low_resource:
            self.vit_to_cpu()
            image = image.to("cpu")

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            inputs_llama = self.llama_proj(query_output.last_hidden_state)
            # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama
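
    # load_projection: load only the trained llama_proj weights from a
    # MiniGPT-4 checkpoint (e.g. pretrained_minigpt4.pth); the vision encoder
    # and Q-Former are already initialized in __init__.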
    def load_projection(self, path):
        state = torch.load(path)["model"]
        self.llama_proj.load_state_dict({
            "weight": state["llama_proj.weight"],
            "bias": state["llama_proj.bias"]})
    def chat(self, question):
        self.model.eval_string("Human: ")
        self.model.eval_string(question)
        self.model.eval_string("\n### Assistant:")
        return self.model.generate_with_print(end="###")
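
    # chat_with_image: same conversation format as chat(), but the image is
    # injected between the <Img></Img> tags as raw float embeddings via
    # eval_float() instead of text tokens.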
    def chat_with_image(self, image, question):
        with torch.no_grad():
            embd_image = self.encode_img(image)
        embd_image = embd_image.cpu().numpy()[0]
        self.model.eval_string("Human: <Img>")
        self.model.eval_float(embd_image.T)
        self.model.eval_string("</Img> ")
        self.model.eval_string(question)
        self.model.eval_string("\n### Assistant:")
        return self.model.generate_with_print(end="###")
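

# Demo: the paths below are examples; they assume a quantized ggml Vicuna-13B
# model under ./models, the pretrained_minigpt4.pth projection checkpoint next
# to this script, and a sample image under ./media.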
if __name__ == "__main__":
    a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"])
    a.load_projection(os.path.join(
        os.path.dirname(__file__),
        "pretrained_minigpt4.pth"))
    response = a.chat_with_image(
        Image.open("./media/llama1-logo.png").convert('RGB'),
        "what is the text in the picture?")
    a.chat("what is the color of it?")