1
0

inspect-org-model.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. #!/usr/bin/env python3
  2. import argparse
  3. import os
  4. import json
  5. from safetensors import safe_open
  6. from collections import defaultdict
  7. parser = argparse.ArgumentParser(description='Process model with specified path')
  8. parser.add_argument('--model-path', '-m', help='Path to the model')
  9. args = parser.parse_args()
  10. model_path = os.environ.get('MODEL_PATH', args.model_path)
  11. if model_path is None:
  12. parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
  13. # Check if there's an index file (multi-file model)
  14. index_path = os.path.join(model_path, "model.safetensors.index.json")
  15. single_file_path = os.path.join(model_path, "model.safetensors")
  16. if os.path.exists(index_path):
  17. # Multi-file model
  18. print("Multi-file model detected")
  19. with open(index_path, 'r') as f:
  20. index_data = json.load(f)
  21. # Get the weight map (tensor_name -> file_name)
  22. weight_map = index_data.get("weight_map", {})
  23. # Group tensors by file for efficient processing
  24. file_tensors = defaultdict(list)
  25. for tensor_name, file_name in weight_map.items():
  26. file_tensors[file_name].append(tensor_name)
  27. print("Tensors in model:")
  28. # Process each shard file
  29. for file_name, tensor_names in file_tensors.items():
  30. file_path = os.path.join(model_path, file_name)
  31. print(f"\n--- From {file_name} ---")
  32. with safe_open(file_path, framework="pt") as f:
  33. for tensor_name in sorted(tensor_names):
  34. tensor = f.get_tensor(tensor_name)
  35. print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}")
  36. elif os.path.exists(single_file_path):
  37. # Single file model (original behavior)
  38. print("Single-file model detected")
  39. with safe_open(single_file_path, framework="pt") as f:
  40. keys = f.keys()
  41. print("Tensors in model:")
  42. for key in sorted(keys):
  43. tensor = f.get_tensor(key)
  44. print(f"- {key} : shape = {tensor.shape}, dtype = {tensor.dtype}")
  45. else:
  46. print(f"Error: Neither 'model.safetensors.index.json' nor 'model.safetensors' found in {model_path}")
  47. print("Available files:")
  48. if os.path.exists(model_path):
  49. for item in sorted(os.listdir(model_path)):
  50. print(f" {item}")
  51. else:
  52. print(f" Directory {model_path} does not exist")
  53. exit(1)