compare-logits.py

#!/usr/bin/env python3
import sys
import numpy as np
from pathlib import Path

# Add utils directory to path for direct script execution
sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
from common import get_model_name_from_env_path  # type: ignore[import-not-found]


def quick_logits_check(pytorch_file, llamacpp_file):
    """Lightweight sanity check before NMSE"""
    try:
        pytorch_logits = np.fromfile(pytorch_file, dtype=np.float32)
        llamacpp_logits = np.fromfile(llamacpp_file, dtype=np.float32)
    except Exception as e:
        print(f"❌ NOK: Failed to load files - {e}")
        return False

    # Check shapes match
    if pytorch_logits.shape != llamacpp_logits.shape:
        print(f"❌ NOK: Shape mismatch - PyTorch: {pytorch_logits.shape}, llama.cpp: {llamacpp_logits.shape}")
        return False

    # Calculate key metrics
    diff = pytorch_logits - llamacpp_logits
    abs_diff = np.abs(diff)
    max_diff = np.max(abs_diff)

    # Get top 10 predictions from both models (token ids, highest logit first)
    pytorch_top10 = np.argsort(pytorch_logits)[-10:][::-1]
    llamacpp_top10 = np.argsort(llamacpp_logits)[-10:][::-1]

    print(f"Top 10 PyTorch logits:   {pytorch_logits[pytorch_top10]}")
    print(f"Top 10 llama.cpp logits: {llamacpp_logits[llamacpp_top10]}")
    print(f"Max absolute difference: {max_diff:.4f}")

    # The top-10 token rankings must agree, otherwise greedy generation diverges
    return bool(np.array_equal(pytorch_top10, llamacpp_top10))
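

# The docstring above refers to a follow-up NMSE comparison that is performed
# outside this script. As an illustrative sketch only (the exact definition and
# threshold used by the original tooling are assumptions, not taken from it),
# the metric is commonly computed as the mean squared error normalized by the
# power of the reference signal, here with the PyTorch logits as the reference:
def nmse(reference, test):
    """Normalized mean squared error of `test` against `reference` (sketch only)."""
    return float(np.mean((test - reference) ** 2) / np.mean(reference ** 2))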


def main():
    model_name = get_model_name_from_env_path('MODEL_PATH')
    data_dir = Path("data")
    pytorch_file = data_dir / f"pytorch-{model_name}.bin"

    llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
    print(f"Using converted model: {llamacpp_model_name}")
    llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"

    if not pytorch_file.exists():
        print(f"Error: PyTorch logits file not found: {pytorch_file}")
        print("Please run scripts/run-org-model.sh first to generate this file.")
        sys.exit(1)

    if not llamacpp_file.exists():
        print(f"Error: llama.cpp logits file not found: {llamacpp_file}")
        print("Please run scripts/run-converted-model.sh first to generate this file.")
        sys.exit(1)

    print("Checked all required files were found. Proceeding...\n")

    print(f"🔍 GGML Model Validation for model {model_name}")
    print("=" * 40)
    print(f"PyTorch logits  : {pytorch_file}")
    print(f"llama.cpp logits: {llamacpp_file}")
    print()

    success = quick_logits_check(pytorch_file, llamacpp_file)

    # Exit with appropriate code
    if success:
        print("✅ OK: Lightweight model check successful!")
        print("   Ok to proceed with NMSE check...")
        sys.exit(0)
    else:
        print("❌ NOK: Top 10 predictions don't match - generation will differ")
        sys.exit(1)


if __name__ == "__main__":
    main()
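

# Example invocation (paths are illustrative assumptions, not taken from the
# original documentation). MODEL_PATH should point at the original model and
# CONVERTED_MODEL at the converted GGUF file; both are read via
# get_model_name_from_env_path(), which is assumed to derive a model name from
# the path each variable points at:
#
#   MODEL_PATH=/path/to/original-model \
#   CONVERTED_MODEL=/path/to/converted-model.gguf \
#   python compare-logits.py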