1
0

check-nmse.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. #!/usr/bin/env python3
  2. import numpy as np
  3. import sys
  4. import os
  5. import argparse
  6. from pathlib import Path
  7. from common import get_model_name_from_env_path # type: ignore[import-not-found]
  8. def calculate_nmse(reference, test):
  9. mse = np.mean((test - reference) ** 2)
  10. ref_var = np.var(reference)
  11. if ref_var == 0:
  12. nmse = float('inf') if mse > 0 else 0.0
  13. return mse, mse, ref_var
  14. nmse = mse / ref_var
  15. return nmse, mse, ref_var
  16. def load_logits(file_path):
  17. if not os.path.exists(file_path):
  18. raise FileNotFoundError(f"File not found: {file_path}")
  19. if file_path.suffix == '.npy':
  20. return np.load(file_path)
  21. elif file_path.suffix == '.bin':
  22. return np.fromfile(file_path, dtype=np.float32)
  23. else:
  24. # Try to load as text file
  25. try:
  26. # If it has index format "0: value", extract just values
  27. data = []
  28. with open(file_path, 'r') as f:
  29. for line in f:
  30. if ':' in line:
  31. # Format: "index: value"
  32. value = float(line.split(':')[1].strip())
  33. else:
  34. # Just the value
  35. value = float(line.strip())
  36. data.append(value)
  37. return np.array(data, dtype=np.float32)
  38. except:
  39. return np.loadtxt(file_path, dtype=np.float32)
  40. def interpret_nmse(nmse):
  41. """Provide interpretation of NMSE value"""
  42. if nmse == 0:
  43. return "Perfect match", "🎉"
  44. elif nmse < 1e-6:
  45. return "Essentially identical", "✅"
  46. elif nmse < 1e-4:
  47. return "Excellent match", "✅"
  48. elif nmse < 1e-3:
  49. return "Very good match", "👍"
  50. elif nmse < 1e-2:
  51. return "Good match", "👍"
  52. elif nmse < 0.1:
  53. return "Acceptable match", "⚠️"
  54. elif nmse < 1.0:
  55. return "Poor match", "❌"
  56. else:
  57. return "Very poor match (worse than noise)", "❌"
  58. def main():
  59. parser = argparse.ArgumentParser(description='Validate model logits')
  60. parser.add_argument('-m', '--model-path', required=True, help='Path to the model directory')
  61. args = parser.parse_args()
  62. model_name = get_model_name_from_env_path('MODEL_PATH')
  63. data_dir = Path("data")
  64. pytorch_file = data_dir / f"pytorch-{model_name}.bin"
  65. llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
  66. llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"
  67. print(f"Model name: {model_name}")
  68. print(f"PyTorch logits file: {pytorch_file}")
  69. print(f"llama.cpp logits file: {llamacpp_file}")
  70. reference_file = pytorch_file
  71. test_file = llamacpp_file
  72. print("📊 NMSE Check for Model Comparison")
  73. print("=" * 50)
  74. print(f"Reference (ground truth): {reference_file}")
  75. print(f"Test (to evaluate): {test_file}")
  76. print()
  77. try:
  78. print("Loading reference logits...")
  79. reference = load_logits(reference_file)
  80. print(f" Shape: {reference.shape}, Type: {reference.dtype}")
  81. print("Loading test logits...")
  82. test = load_logits(test_file)
  83. print(f" Shape: {test.shape}, Type: {test.dtype}")
  84. # Check shapes match
  85. if reference.shape != test.shape:
  86. print(f"\n❌ Error: Shape mismatch!")
  87. print(f" Reference: {reference.shape}")
  88. print(f" Test: {test.shape}")
  89. sys.exit(1)
  90. print(f"\n✅ Shapes match: {reference.shape}")
  91. nmse, mse, ref_var = calculate_nmse(reference, test)
  92. # Additional metrics
  93. max_abs_error = np.max(np.abs(test - reference))
  94. mean_abs_error = np.mean(np.abs(test - reference))
  95. # Results
  96. print(f"\n📈 METRICS")
  97. print("=" * 30)
  98. print(f"MSE (Mean Squared Error): {mse:.6e}")
  99. print(f"Reference Variance: {ref_var:.6e}")
  100. print(f"NMSE: {nmse:.6e}")
  101. print(f"Max Absolute Error: {max_abs_error:.6f}")
  102. print(f"Mean Absolute Error: {mean_abs_error:.6f}")
  103. # NMSE in dB (common in signal processing)
  104. if nmse > 0:
  105. nmse_db = 10 * np.log10(nmse)
  106. print(f"NMSE (dB): {nmse_db:.2f} dB")
  107. # Interpretation
  108. interpretation, emoji = interpret_nmse(nmse)
  109. print(f"\n🎯 INTERPRETATION")
  110. print("=" * 30)
  111. print(f"{emoji} {interpretation}")
  112. # Detailed guidance
  113. print(f"\n📋 GUIDANCE")
  114. print("=" * 30)
  115. if nmse < 1e-3:
  116. print("✅ EXCELLENT: Your GGML conversion is working very well!")
  117. print(" The differences are negligible for practical use.")
  118. elif nmse < 1e-2:
  119. print("👍 GOOD: Your GGML conversion is working well.")
  120. print(" Small differences are likely due to precision/quantization.")
  121. elif nmse < 0.1:
  122. print("⚠️ ACCEPTABLE: Conversion is working but with some differences.")
  123. print(" Check if you're using quantization (Q4, Q8, etc.)")
  124. print(" Test generation quality to see if it's acceptable.")
  125. else:
  126. print("❌ PROBLEMATIC: Large differences detected.")
  127. print(" Check your conversion process for potential issues.")
  128. print(" Verify you're using the same model weights.")
  129. # NMSE benchmarks
  130. print(f"\n📚 NMSE BENCHMARKS")
  131. print("=" * 30)
  132. print("< 1e-6: Essentially identical")
  133. print("< 1e-4: Excellent (typical for good conversions)")
  134. print("< 1e-3: Very good")
  135. print("< 1e-2: Good (acceptable for most use cases)")
  136. print("< 0.1: Acceptable (may need verification)")
  137. print("> 1.0: Poor (worse than random)")
  138. # Exit code based on NMSE
  139. if nmse < 1e-2:
  140. print(f"\n✅ RESULT: PASS (NMSE = {nmse:.2e})")
  141. sys.exit(0)
  142. else:
  143. print(f"\n❌ RESULT: NEEDS REVIEW (NMSE = {nmse:.2e})")
  144. sys.exit(1)
  145. except Exception as e:
  146. print(f"❌ Error: {e}")
  147. sys.exit(1)
  148. if __name__ == "__main__":
  149. main()