llava.py

#!/usr/bin/env python3
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from embd_input import MyModel
import numpy as np
from torch import nn
import torch
from transformers import CLIPVisionModel, CLIPImageProcessor
from PIL import Image

# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1'
vision_tower = "openai/clip-vit-large-patch14"
select_hidden_state_layer = -2
# (vision_config.image_size // vision_config.patch_size) ** 2
image_token_len = (224//14)**2
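# i.e. 224 // 14 = 16 patches per side, so the vision tower yields 16 * 16 = 256 patch embeddings per image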


class Llava:
    def __init__(self, args):
        self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
        self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
        # project 1024-dim CLIP ViT-L/14 patch features into the 5120-dim LLaMA-13B embedding space
        self.mm_projector = nn.Linear(1024, 5120)
        self.model = MyModel(["main", *args])

    def load_projection(self, path):
        state = torch.load(path)
        self.mm_projector.load_state_dict({
            "weight": state["model.mm_projector.weight"],
            "bias": state["model.mm_projector.bias"]})

    def chat(self, question):
        self.model.eval_string("user: ")
        self.model.eval_string(question)
        self.model.eval_string("\nassistant: ")
        return self.model.generate_with_print()

    def chat_with_image(self, image, question):
        with torch.no_grad():
            # run the image through the CLIP vision tower
            embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
            # take the selected hidden layer (second-to-last) and drop the leading CLS token
            select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
            image_feature = select_hidden_state[:, 1:]
            # project the patch features into the LLaMA embedding space
            embd_image = self.mm_projector(image_feature)
            embd_image = embd_image.cpu().numpy()[0]
        self.model.eval_string("user: ")
        self.model.eval_token(32003-2) # im_start
        # feed the projected image embeddings directly into the context
        self.model.eval_float(embd_image.T)
        # pad with im_patch tokens up to the fixed image token length
        for i in range(image_token_len-embd_image.shape[0]):
            self.model.eval_token(32003-3) # im_patch
        self.model.eval_token(32003-1) # im_end
        self.model.eval_string(question)
        self.model.eval_string("\nassistant: ")
        return self.model.generate_with_print()


if __name__ == "__main__":
    # model from liuhaotian/LLaVA-13b-delta-v1-1
    a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
    # The projection weights can be extracted from
    # https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin.
    # Alternatively, pytorch_model-00003-of-00003.bin itself can be passed to load_projection() directly.
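    # A rough sketch of that extraction, assuming the shard has been downloaded locally as
    # pytorch_model-00003-of-00003.bin (the key names match those read in load_projection above):
    #
    #   state = torch.load("pytorch_model-00003-of-00003.bin", map_location="cpu")
    #   torch.save({k: v for k, v in state.items() if "mm_projector" in k},
    #              "llava_projection.pth")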
    a.load_projection(os.path.join(
        os.path.dirname(__file__),
        "llava_projection.pth"))
    response = a.chat_with_image(
        Image.open("./media/llama1-logo.png").convert('RGB'),
        "what is the text in the picture?")
    a.chat("what is the color of it?")