| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- #!/usr/bin/env python3
- import numpy as np
- import sys
- import os
- from pathlib import Path
- def quick_logits_check(pytorch_file, llamacpp_file):
- """Lightweight sanity check before NMSE"""
- try:
- pytorch_logits = np.fromfile(pytorch_file, dtype=np.float32)
- llamacpp_logits = np.fromfile(llamacpp_file, dtype=np.float32)
- except Exception as e:
- print(f"❌ NOK: Failed to load files - {e}")
- return False
- # Check shapes match
- if pytorch_logits.shape != llamacpp_logits.shape:
- print(f"❌ NOK: Shape mismatch - PyTorch: {pytorch_logits.shape}, llama.cpp: {llamacpp_logits.shape}")
- return False
- # Calculate key metrics
- diff = pytorch_logits - llamacpp_logits
- abs_diff = np.abs(diff)
- max_diff = np.max(abs_diff)
- # Get top 10 predictions from both models
- pytorch_top10 = np.argsort(pytorch_logits)[-10:][::-1]
- llamacpp_top10 = np.argsort(llamacpp_logits)[-10:][::-1]
- print(f"Top 10 PyTorch logits: {pytorch_logits[pytorch_top10]}")
- print(f"Top 10 llama.cpp logits: {llamacpp_logits[llamacpp_top10]}")
- print(f"Max absolute difference: {max_diff:.4f}")
- if max_diff > 1.0:
- print(f"❌ NOK: Large differences detected - max diff: {max_diff:.4f}")
- return False
- return True
- def main():
- model_path = os.getenv('MODEL_PATH')
- if not model_path:
- print("Error: MODEL_PATH environment variable not set")
- sys.exit(1)
- if not os.path.exists(model_path):
- print(f"Error: Model file not found: {model_path}")
- sys.exit(1)
- model_name = os.path.basename(model_path)
- data_dir = Path("data")
- pytorch_file = data_dir / f"pytorch-{model_name}.bin"
- llamacpp_file = data_dir / f"llamacpp-{model_name}.bin"
- if not pytorch_file.exists():
- print(f"Error: PyTorch logits file not found: {pytorch_file}")
- print("Please run scripts/run-org-model.sh first to generate this file.")
- sys.exit(1)
- if not llamacpp_file.exists():
- print(f"Error: llama.cpp logits file not found: {llamacpp_file}")
- print("Please run scripts/run-converted-model.sh first to generate this file.")
- sys.exit(1)
- print("Checked all required files were found. Proceeding...\n")
- print("🔍 GGML Model Validation for model ", model_name)
- print("=" * 40)
- print(f"PyTorch logits : {pytorch_file}")
- print(f"llama.cpp logits: {llamacpp_file}")
- print()
- success = quick_logits_check(pytorch_file, llamacpp_file)
- # Exit with appropriate code
- if success:
- print("✅ OK: Lightweight model check successful!")
- print(" Ok to proceed with NMSE check...")
- sys.exit(0)
- else:
- print(f"❌ NOK: Top 10 predictions don't match - generation will differ")
- sys.exit(1)
- if __name__ == "__main__":
- main()
|