1
0

compare-logits.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #!/usr/bin/env python3
  2. import numpy as np
  3. import sys
  4. import os
  5. from pathlib import Path
  6. def quick_logits_check(pytorch_file, llamacpp_file):
  7. """Lightweight sanity check before NMSE"""
  8. try:
  9. pytorch_logits = np.fromfile(pytorch_file, dtype=np.float32)
  10. llamacpp_logits = np.fromfile(llamacpp_file, dtype=np.float32)
  11. except Exception as e:
  12. print(f"❌ NOK: Failed to load files - {e}")
  13. return False
  14. # Check shapes match
  15. if pytorch_logits.shape != llamacpp_logits.shape:
  16. print(f"❌ NOK: Shape mismatch - PyTorch: {pytorch_logits.shape}, llama.cpp: {llamacpp_logits.shape}")
  17. return False
  18. # Calculate key metrics
  19. diff = pytorch_logits - llamacpp_logits
  20. abs_diff = np.abs(diff)
  21. max_diff = np.max(abs_diff)
  22. # Get top 10 predictions from both models
  23. pytorch_top10 = np.argsort(pytorch_logits)[-10:][::-1]
  24. llamacpp_top10 = np.argsort(llamacpp_logits)[-10:][::-1]
  25. print(f"Top 10 PyTorch logits: {pytorch_logits[pytorch_top10]}")
  26. print(f"Top 10 llama.cpp logits: {llamacpp_logits[llamacpp_top10]}")
  27. print(f"Max absolute difference: {max_diff:.4f}")
  28. if max_diff > 1.0:
  29. print(f"❌ NOK: Large differences detected - max diff: {max_diff:.4f}")
  30. return False
  31. return True
  32. def main():
  33. model_path = os.getenv('MODEL_PATH')
  34. if not model_path:
  35. print("Error: MODEL_PATH environment variable not set")
  36. sys.exit(1)
  37. if not os.path.exists(model_path):
  38. print(f"Error: Model file not found: {model_path}")
  39. sys.exit(1)
  40. model_name = os.path.basename(model_path)
  41. data_dir = Path("data")
  42. pytorch_file = data_dir / f"pytorch-{model_name}.bin"
  43. llamacpp_file = data_dir / f"llamacpp-{model_name}.bin"
  44. if not pytorch_file.exists():
  45. print(f"Error: PyTorch logits file not found: {pytorch_file}")
  46. print("Please run scripts/run-org-model.sh first to generate this file.")
  47. sys.exit(1)
  48. if not llamacpp_file.exists():
  49. print(f"Error: llama.cpp logits file not found: {llamacpp_file}")
  50. print("Please run scripts/run-converted-model.sh first to generate this file.")
  51. sys.exit(1)
  52. print("Checked all required files were found. Proceeding...\n")
  53. print("🔍 GGML Model Validation for model ", model_name)
  54. print("=" * 40)
  55. print(f"PyTorch logits : {pytorch_file}")
  56. print(f"llama.cpp logits: {llamacpp_file}")
  57. print()
  58. success = quick_logits_check(pytorch_file, llamacpp_file)
  59. # Exit with appropriate code
  60. if success:
  61. print("✅ OK: Lightweight model check successful!")
  62. print(" Ok to proceed with NMSE check...")
  63. sys.exit(0)
  64. else:
  65. print(f"❌ NOK: Top 10 predictions don't match - generation will differ")
  66. sys.exit(1)
# Entry point: run the validation only when executed as a script, not on import.
if __name__ == "__main__":
    main()