| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- import argparse
- import glob
- import os
- import torch
- from safetensors import safe_open
- from safetensors.torch import save_file
- from typing import Any, ContextManager, cast
# Predicate used to pick the on-disk serialization format.
def is_safetensor_file(file_path):
    """Return True when *file_path* names a safetensors checkpoint (by extension)."""
    safetensors_suffix = '.safetensors'
    return file_path.endswith(safetensors_suffix)
# Unified loading function
def load_model(file_path):
    """Load a checkpoint into a dict of CPU tensors.

    Returns a ``(state_dict, fmt)`` pair where ``fmt`` is ``'safetensor'``
    or ``'pytorch'``, so callers can round-trip the file in its original
    format later.
    """
    if not is_safetensor_file(file_path):
        # Anything that is not .safetensors is treated as a torch checkpoint.
        return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch'
    tensors = {}
    # safe_open maps the file lazily; clone() materializes each tensor so it
    # stays valid after the context manager closes.
    with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as reader:
        for key in reader.keys():
            tensors[key] = reader.get_tensor(key).clone()
            # output shape
            print(f"{key} : {tensors[key].shape}")
    return tensors, 'safetensor'
# Unified saving function
def save_model(model, file_path, file_type):
    """Write a state dict to *file_path*.

    ``file_type == 'safetensor'`` writes via safetensors' ``save_file``;
    any other value falls back to a plain ``torch.save`` checkpoint.
    """
    if file_type != 'safetensor':
        torch.save(model, file_path)
    else:
        save_file(model, file_path)
# Adapted function to clean vision tower from checkpoint
def clean_vision_tower_from_checkpoint(checkpoint_path):
    """Move vision-tower tensors out of a checkpoint into <model_dir>/llava.clip.

    Extracts every tensor whose name marks it as part of the vision tower
    ("model.vision_tower*" or "vit.*"), merges them into the llava.clip file
    next to the checkpoint, strips them from the checkpoint, and writes the
    checkpoint back in its original format.

    Returns True when vision-tower tensors were found and moved, else False.
    """
    checkpoint, file_type = load_model(checkpoint_path)
    model_path = os.path.dirname(checkpoint_path)
    print(f"Searching for vision tower tensors in {checkpoint_path}")
    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]

    if len(clip_tensors) == 0:
        return False

    print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
    # llava.clip accumulates tensors across all shards of the model.
    clip_path = os.path.join(model_path, "llava.clip")
    if os.path.exists(clip_path):
        print(f"Loading existing llava.clip from {clip_path}")
        existing_clip, _ = load_model(clip_path)
    else:
        print(f"Creating new llava.clip at {clip_path}")
        existing_clip = {}
    # Update existing_clip with new tensors, avoid duplicates
    for name in clip_tensors:
        # Strip the wrapper prefix so stored names start at 'vision_model.' when present.
        simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
        print(f"Adding {simple_name} to llava.clip")
        if simple_name not in existing_clip:
            existing_clip[simple_name] = checkpoint[name]

    # Save the updated clip tensors back to llava.clip (always plain torch format).
    save_model(existing_clip, clip_path, 'pytorch')

    # Remove the tensors from the original checkpoint
    for name in clip_tensors:
        del checkpoint[name]

    # Fix: persist the stripped checkpoint. Previously the deletion happened
    # only in memory (the old `checkpoint_path = checkpoint_path` line was a
    # no-op), so the vision-tower tensors were never actually removed from
    # the model file despite the script reporting that they were.
    save_model(checkpoint, checkpoint_path, file_type)
    return True
def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
    """Scan checkpoints for the image-newline tensor and the projector tensors.

    Returns ``(newline_path, projector_path)``: the first path whose contents
    satisfy *newline_criteria*, and the last path satisfying *projector*.
    Either entry may be None when no checkpoint matches.
    """
    newline_path = None
    projector_path = None
    for candidate in checkpoint_paths:
        state, _ = load_model(candidate)
        # Keep only the first newline hit; later projector hits overwrite earlier ones.
        if newline_path is None and newline_criteria(state):
            newline_path = candidate
        if projector(state):
            projector_path = candidate
    return newline_path, projector_path
def newline_criteria(checkpoint):
    """Return True when the checkpoint carries a model.image_newline tensor."""
    for key in checkpoint.keys():
        if key.startswith("model.image_newline"):
            return True
    return False
def proj_criteria(checkpoint):
    """Return True when the checkpoint holds projector tensors (LLaVA or vision_proj naming)."""
    prefixes = ("model.mm_projector", "vision_proj.")
    return any(key.startswith(prefixes) for key in checkpoint)
# Command-line interface setup
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model")
ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files")
args = ap.parse_args()

if args.clean_vision_tower:
    # Generalized to handle both PyTorch and SafeTensors models
    # Newest-first by mtime; see the commented-out `break` below for why order matters.
    model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
    # checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
    # Match shard basenames; the double split handles both '/' and '\' separators.
    checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
    for projector_checkpoint_path in checkpoint_paths:
        print(f"Cleaning {projector_checkpoint_path}")
        if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
            print(f"No vision tower found in {projector_checkpoint_path}")
            # we break once none is found, so far all models append them at the end
            # break
    print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")

# Now we look for the projector in the last checkpoint
# Re-glob: cleaning above may have rewritten files and changed mtimes.
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]

# last_checkpoint_path = checkpoint_paths[0]
# first_checkpoint_path = checkpoint_paths[-1]
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)

print(f"Taking projector from {projector_checkpoint_path}")
# The image_newline tensor may live in a different shard than the projector.
first_mm_tensors = []
first_checkpoint = None
if newline_checkpoint_path is not None:
    print(f"Taking newline from {newline_checkpoint_path}")
    first_checkpoint, file_type = load_model(newline_checkpoint_path)
    first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]

# Load the checkpoint
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
    last_checkpoint, file_type = load_model(projector_checkpoint_path)
    mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]

if len(mm_tensors) == 0:
    # Dump all tensor names to help diagnose a non-LLaVA checkpoint, then bail out.
    if last_checkpoint is not None:
        for k, v in last_checkpoint.items():
            print(k)
    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
    print("No tensors found. Is this a LLaVA model?")
    exit()

print(f"Found {len(mm_tensors)} tensors to extract.")
print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
# projector = {name: checkpoint.[name].float() for name in mm_tensors}
# Collect projector (and newline) tensors as float32 for the GGUF conversion step.
projector = {}
for name in mm_tensors:
    assert last_checkpoint is not None
    projector[name] = last_checkpoint[name].float()
for name in first_mm_tensors:
    assert first_checkpoint is not None
    projector[name] = first_checkpoint[name].float()

if len(projector) > 0:
    save_model(projector, f"{args.model}/llava.projector", 'pytorch')

print("Done!")
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")
|