# llava_surgery_v2.py — split a LLaVA checkpoint into llava.clip / llava.projector parts
  1. import argparse
  2. import glob
  3. import os
  4. import torch
  5. from safetensors import safe_open
  6. from safetensors.torch import save_file
  7. from typing import Any, ContextManager, cast
  8. # Function to determine if file is a SafeTensor file
  9. def is_safetensor_file(file_path):
  10. return file_path.endswith('.safetensors')
  11. # Unified loading function
  12. def load_model(file_path):
  13. if is_safetensor_file(file_path):
  14. tensors = {}
  15. with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
  16. for key in f.keys():
  17. tensors[key] = f.get_tensor(key).clone()
  18. # output shape
  19. print(f"{key} : {tensors[key].shape}")
  20. return tensors, 'safetensor'
  21. else:
  22. return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch'
  23. # Unified saving function
  24. def save_model(model, file_path, file_type):
  25. if file_type == 'safetensor':
  26. # safe_save(model, file_path)
  27. save_file(model, file_path)
  28. else:
  29. torch.save(model, file_path)
  30. # Helpers to match weight names from specific components or
  31. # determine if a saved shard contains that component
  32. def is_vision_tower(weight_name):
  33. return (
  34. weight_name.startswith("model.vision_tower") or
  35. weight_name.startswith("vit.") or
  36. weight_name.startswith("vision_tower")
  37. )
  38. def is_newline(weight_name):
  39. return (
  40. weight_name.startswith("model.image_newline") or
  41. weight_name.startswith("image_newline")
  42. )
  43. def is_mm_projector(weight_name):
  44. return (
  45. weight_name.startswith("model.mm_projector") or
  46. weight_name.startswith("vision_proj.") or
  47. weight_name.startswith("multi_modal_projector")
  48. )
  49. def newline_criteria(checkpoint):
  50. return any(is_newline(k) for k in checkpoint.keys())
  51. def proj_criteria(checkpoint):
  52. return any(is_mm_projector(k) for k in checkpoint.keys())
  53. # Adapted function to clean vision tower from checkpoint
  54. def clean_vision_tower_from_checkpoint(checkpoint_path):
  55. checkpoint, file_type = load_model(checkpoint_path)
  56. # file_type = 'pytorch'
  57. model_path = os.path.dirname(checkpoint_path)
  58. print(f"Searching for vision tower tensors in {checkpoint_path}")
  59. clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)]
  60. if len(clip_tensors) > 0:
  61. print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
  62. # Adapted for file type
  63. clip_path = os.path.join(model_path, "llava.clip")
  64. if os.path.exists(clip_path):
  65. print(f"Loading existing llava.clip from {clip_path}")
  66. existing_clip, _ = load_model(clip_path)
  67. else:
  68. print(f"Creating new llava.clip at {clip_path}")
  69. existing_clip = {}
  70. # Update existing_clip with new tensors, avoid duplicates
  71. for name in clip_tensors:
  72. simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
  73. print(f"Adding {simple_name} to llava.clip")
  74. if simple_name not in existing_clip:
  75. existing_clip[simple_name] = checkpoint[name]
  76. # Save the updated clip tensors back to llava.clip
  77. save_model(existing_clip, clip_path, 'pytorch')
  78. # Remove the tensors from the original checkpoint
  79. for name in clip_tensors:
  80. del checkpoint[name]
  81. checkpoint_path = checkpoint_path
  82. return True
  83. return False
  84. def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
  85. newline_checkpoint_path = None
  86. projector_checkpoint_path = None
  87. for path in checkpoint_paths:
  88. checkpoint, _ = load_model(path)
  89. if newline_criteria(checkpoint) and newline_checkpoint_path is None:
  90. newline_checkpoint_path = path
  91. if projector(checkpoint):
  92. projector_checkpoint_path = path
  93. return newline_checkpoint_path, projector_checkpoint_path
  94. # Command-line interface setup
  95. ap = argparse.ArgumentParser()
  96. ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model")
  97. ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files")
  98. args = ap.parse_args()
  99. if args.clean_vision_tower:
  100. # Generalized to handle both PyTorch and SafeTensors models
  101. model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
  102. # checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
  103. checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
  104. for projector_checkpoint_path in checkpoint_paths:
  105. print(f"Cleaning {projector_checkpoint_path}")
  106. if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
  107. print(f"No vision tower found in {projector_checkpoint_path}")
  108. # we break once none is found, so far all models append them at the end
  109. # break
  110. print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")
  111. # Now we look for the projector in the last checkpoint
  112. model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
  113. checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
  114. # last_checkpoint_path = checkpoint_paths[0]
  115. # first_checkpoint_path = checkpoint_paths[-1]
  116. newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
  117. print(f"Taking projector from {projector_checkpoint_path}")
  118. first_mm_tensors = []
  119. first_checkpoint = None
  120. if newline_checkpoint_path is not None:
  121. print(f"Taking newline from {newline_checkpoint_path}")
  122. first_checkpoint, file_type = load_model(newline_checkpoint_path)
  123. first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)]
  124. # Load the checkpoint
  125. mm_tensors = []
  126. last_checkpoint = None
  127. if projector_checkpoint_path is not None:
  128. last_checkpoint, file_type = load_model(projector_checkpoint_path)
  129. mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)]
  130. if len(mm_tensors) == 0:
  131. if last_checkpoint is not None:
  132. for k, v in last_checkpoint.items():
  133. print(k)
  134. print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
  135. print("No tensors found. Is this a LLaVA model?")
  136. exit()
  137. print(f"Found {len(mm_tensors)} tensors to extract.")
  138. print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
  139. # projector = {name: checkpoint.[name].float() for name in mm_tensors}
  140. projector = {}
  141. for name in mm_tensors:
  142. assert last_checkpoint is not None
  143. projector[name] = last_checkpoint[name].float()
  144. for name in first_mm_tensors:
  145. assert first_checkpoint is not None
  146. projector[name] = first_checkpoint[name].float()
  147. if len(projector) > 0:
  148. save_model(projector, f"{args.model}/llava.projector", 'pytorch')
  149. print("Done!")
  150. print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
  151. print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")