panda_gpt.py

#!/usr/bin/env python3
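"""Multimodal chat with PandaGPT.

ImageBind encodes image/audio/video/thermal inputs, a linear layer projects the
embeddings into the LLaMA embedding space, and the projected vectors are fed to
a ggml Vicuna model through the embd_input.MyModel wrapper, using PandaGPT's
<Img>...</Img> prompt markers.
"""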
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from embd_input import MyModel
import numpy as np
from torch import nn
import torch

# use PandaGPT path
panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT")
imagebind_ckpt_path = "./models/panda_gpt/"
sys.path.insert(0, os.path.join(panda_gpt_path, "code", "model"))
from ImageBind.models import imagebind_model
from ImageBind import data

ModalityType = imagebind_model.ModalityType
max_tgt_len = 400
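

# PandaGPT ties together a frozen ImageBind encoder, a linear projection into
# the LLaMA embedding space, and a MyModel instance running the (LoRA-adapted)
# language model.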
class PandaGPT:
    def __init__(self, args):
        self.visual_encoder, _ = imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path)
        self.visual_encoder.eval()
        self.llama_proj = nn.Linear(1024, 5120)  # visual_hidden_size (1024) -> LLaMA hidden size (5120)
        self.max_tgt_len = max_tgt_len
        self.model = MyModel(["main", *args])
        self.generated_text = ""
        self.device = "cpu"
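
    # Load the pretrained llama_proj weights from PandaGPT's adapter checkpoint.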
    def load_projection(self, path):
        state = torch.load(path, map_location="cpu")
        self.llama_proj.load_state_dict({
            "weight": state["llama_proj.weight"],
            "bias": state["llama_proj.bias"]})
    def eval_inputs(self, inputs):
        self.model.eval_string("<Img>")
        embds = self.extract_multimodal_feature(inputs)
        for i in embds:
            self.model.eval_float(i.T)
        self.model.eval_string("</Img> ")
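
    # chat() handles a text-only turn; chat_with_image() optionally prepends
    # multimodal inputs. Both use the Vicuna-style "### Human: / ### Assistant:"
    # prompt and generate until the next "###" separator.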
    def chat(self, question):
        return self.chat_with_image(None, question)

    def chat_with_image(self, inputs, question):
        if self.generated_text == "":
            self.model.eval_string("###")
        self.model.eval_string(" Human: ")
        if inputs:
            self.eval_inputs(inputs)
        self.model.eval_string(question)
        self.model.eval_string("\n### Assistant:")
        ret = self.model.generate_with_print(end="###")
        self.generated_text += ret
        return ret
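
    # Collect one embedding batch per modality present in `inputs`
    # (keys "image_paths", "audio_paths", "video_paths", "thermal_paths").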
    def extract_multimodal_feature(self, inputs):
        features = []
        for key in ["image", "audio", "video", "thermal"]:
            if key + "_paths" in inputs:
                embeds = self.encode_data(key, inputs[key + "_paths"])
                features.append(embeds)
        return features
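
    # Run ImageBind on the paths for a single modality and project the resulting
    # embeddings from ImageBind's 1024-dim space to the 5120-dim LLaMA hidden size.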
    def encode_data(self, data_type, data_paths):
        type_map = {
            "image": ModalityType.VISION,
            "audio": ModalityType.AUDIO,
            "video": ModalityType.VISION,
            "thermal": ModalityType.THERMAL,
        }
        load_map = {
            "image": data.load_and_transform_vision_data,
            "audio": data.load_and_transform_audio_data,
            "video": data.load_and_transform_video_data,
            "thermal": data.load_and_transform_thermal_data,
        }
        load_function = load_map[data_type]
        key = type_map[data_type]
        inputs = {key: load_function(data_paths, self.device)}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            embeds = embeddings[key]
        embeds = self.llama_proj(embeds).cpu().numpy()
        return embeds
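

# Example session. The model, LoRA adapter, and projection paths below are just
# the defaults used by this example; adjust them to your local setup.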
if __name__ == "__main__":
    a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048",
                  "--lora", "./models/panda_gpt/ggml-adapter-model.bin", "--temp", "0"])
    a.load_projection("./models/panda_gpt/adapter_model.bin")
    a.chat_with_image(
        {"image_paths": ["./media/llama1-logo.png"]},
        "what is the text in the picture? 'llama' or 'lambda'?")
    a.chat("what is the color of it?")