semantic_check.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. #!/usr/bin/env python3
import argparse
import importlib
import os
import sys
from pathlib import Path

import numpy as np
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel

from common import compare_tokens, exit_with_warning # type: ignore[import-not-found]
  9. unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
  10. def cosine_similarity(a, b=None):
  11. a = np.asarray(a)
  12. if b is None:
  13. b = a
  14. else:
  15. b = np.asarray(b)
  16. if a.ndim == 1:
  17. a = a.reshape(1, -1)
  18. if b.ndim == 1:
  19. b = b.reshape(1, -1)
  20. a_norms = np.linalg.norm(a, axis=1, keepdims=True)
  21. b_norms = np.linalg.norm(b, axis=1, keepdims=True)
  22. a_norms = np.where(a_norms == 0, 1e-8, a_norms)
  23. b_norms = np.where(b_norms == 0, 1e-8, b_norms)
  24. a_normalized = a / a_norms
  25. b_normalized = b / b_norms
  26. # Compute cosine similarity
  27. return np.dot(a_normalized, b_normalized.T)
  28. def load_embeddings_from_file(filename, n_tokens, n_embd):
  29. embeddings = np.fromfile(filename, dtype=np.float32)
  30. # Check if this is pooled (single embedding) or per-token embeddings
  31. if len(embeddings) == n_embd:
  32. return embeddings.reshape(1, n_embd)
  33. else:
  34. return embeddings.reshape(n_tokens, n_embd)
  35. def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
  36. np.set_printoptions(suppress=True, precision=6)
  37. print("pytorch embeddings:");
  38. print(python_emb)
  39. print("llama.cpp embeddings:");
  40. print(cpp_emb)
  41. print(f"\n=== Prompt: '{prompt}' ===")
  42. print(f"Tokens: {tokens}")
  43. print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}")
  44. n_tokens = len(tokens)
  45. is_pooled = python_emb.shape[0] == 1
  46. if is_pooled:
  47. print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]")
  48. # 1. Direct embedding comparison for pooled embeddings
  49. print(f"\n1. Raw Embedding Magnitude Comparison:")
  50. py_mag = np.linalg.norm(python_emb[0])
  51. cpp_mag = np.linalg.norm(cpp_emb[0])
  52. ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
  53. print(f" Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
  54. # 2. Cross-model similarity for pooled embeddings
  55. print(f"\n2. Cross-Model Pooled Embedding Similarity:")
  56. sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0]
  57. print(f" Cosine similarity: {sim:.6f}")
  58. return {
  59. 'cross_model_similarities': [sim],
  60. 'similarity_matrix_diff': np.array([[0.0]]),
  61. 'max_diff': 0.0,
  62. 'mean_diff': 0.0,
  63. 'rms_diff': 0.0
  64. }
  65. else:
  66. # Original per-token comparison logic
  67. # 1. Direct embedding comparison
  68. print(f"\n1. Raw Embedding Magnitude Comparison:")
  69. # Check if the distance of each token embedding from the origin and compare
  70. # if the vectors are on the same "sphere". This does not tell us about
  71. # direction (meaning of the token embedding), just magnitude.
  72. for i in range(n_tokens):
  73. py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
  74. cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings
  75. ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
  76. print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
  77. # 2. Cosine similarity between tokens within each model
  78. # Here we check the direction of token embeddings to see if the have the
  79. # same meaning (similarity). This is done by calculating cosine similarity
  80. # of a pair of token embeddings within each model.
  81. print(f"\n2. Within-Model Token Similarities:")
  82. print(" Python model:")
  83. for i in range(n_tokens):
  84. for j in range(i+1, n_tokens):
  85. sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
  86. print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
  87. print(" llama.cpp model:")
  88. for i in range(n_tokens):
  89. for j in range(i+1, n_tokens):
  90. sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
  91. print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
  92. # 3. Cross-model similarity (same token position)
  93. print(f"\n3. Cross-Model Same-Token Similarities:")
  94. for i in range(n_tokens):
  95. sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
  96. print(f" Token {i} ({tokens[i]}): {sim:.4f}")
  97. # 4. Similarity matrix comparison
  98. print(f"\n4. Similarity Matrix Differences:")
  99. py_sim_matrix = cosine_similarity(python_emb)
  100. cpp_sim_matrix = cosine_similarity(cpp_emb)
  101. diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
  102. print(f" Max difference: {np.max(diff_matrix):.4f}")
  103. print(f" Mean difference: {np.mean(diff_matrix):.4f}")
  104. print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
  105. return {
  106. 'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
  107. 'similarity_matrix_diff': diff_matrix,
  108. 'max_diff': np.max(diff_matrix),
  109. 'mean_diff': np.mean(diff_matrix),
  110. 'rms_diff': np.sqrt(np.mean(diff_matrix**2))
  111. }
  112. def read_prompt_from_file(file_path):
  113. try:
  114. with open(file_path, 'r', encoding='utf-8') as f:
  115. return f.read().strip()
  116. except FileNotFoundError:
  117. print(f"Error: Prompts file '{file_path}' not found")
  118. exit(1)
  119. except Exception as e:
  120. print(f"Error reading prompts file: {e}")
  121. exit(1)
  122. def main():
  123. parser = argparse.ArgumentParser(description='Test semantic similarity between Python and llama.cpp embeddings')
  124. parser.add_argument('--model-path', '-m', required=True, help='Path to the original Python model')
  125. parser.add_argument('--python-embeddings', '-pe', help='Path to pytorch embeddings "logits" binary file')
  126. parser.add_argument('--cpp-embeddings', '-ce', help='Path to llama.cpp embeddings "logits" binary file')
  127. parser.add_argument('--causal', '-c', default=False, help='if the model is causal (default: false)', action='store_true')
  128. parser.add_argument('--prompt', '-p', default='Hello world today', help='Test prompt')
  129. parser.add_argument('--prompts-file', '-pf', help='Path to file containing prompts')
  130. args = parser.parse_args()
  131. if args.prompts_file:
  132. prompt = read_prompt_from_file(args.prompts_file)
  133. else:
  134. prompt = args.prompt
  135. python_emb_path = Path(args.python_embeddings)
  136. cpp_emb_path = Path(args.cpp_embeddings)
  137. # Extract base names (e.g., "pytorch-model-name-embeddings.bin" -> "pytorch-model-name")
  138. python_model_name = python_emb_path.stem.replace("-embeddings", "")
  139. cpp_model_name = cpp_emb_path.stem.replace("-embeddings", "")
  140. print("Semantic Similarity Test Between Python and llama.cpp Embedding Models")
  141. print("=" * 70)
  142. # First verify tokens match before comparing embeddings
  143. print("\n🔍 Token Comparison Check")
  144. print("=" * 70)
  145. data_dir = python_emb_path.parent
  146. if not compare_tokens(python_model_name, cpp_model_name, type_suffix="-embeddings", output_dir=str(data_dir)):
  147. exit_with_warning("\n❌ Token mismatch detected", args.model_path)
  148. print()
  149. # Single prompt detailed comparison
  150. print(f"\nTesting with prompt: '{prompt}'")
  151. # Load the python model to get configuration information and also to load the tokenizer.
  152. print("Loading model and tokenizer using AutoTokenizer:", args.model_path)
  153. tokenizer = AutoTokenizer.from_pretrained(args.model_path)
  154. config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
  155. if unreleased_model_name:
  156. model_name_lower = unreleased_model_name.lower()
  157. unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
  158. if args.causal:
  159. class_name = f"{unreleased_model_name}ForCausalLM"
  160. else:
  161. class_name = f"{unreleased_model_name}Model"
  162. print(f"Model class: {class_name}")
  163. print(f"Importing unreleased model module: {unreleased_module_path}")
  164. try:
  165. model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
  166. model = model_class.from_pretrained(args.model_path)
  167. except (ImportError, AttributeError) as e:
  168. print(f"Failed to import or load model: {e}")
  169. exit(1)
  170. else:
  171. if args.causal:
  172. model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
  173. else:
  174. model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
  175. encoded = tokenizer(prompt, return_tensors="pt")
  176. tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
  177. n_tokens = len(tokens)
  178. print(f"n_tokens: {n_tokens}");
  179. print(f"hidden_size: {model.config.hidden_size}")
  180. # Load binary embeddings from data directory.
  181. llamacpp_embeddings = load_embeddings_from_file(args.cpp_embeddings, n_tokens, model.config.hidden_size)
  182. python_embeddings = load_embeddings_from_file(args.python_embeddings, n_tokens, model.config.hidden_size)
  183. # Run comparison
  184. results = test_single_prompt_similarity(python_embeddings, llamacpp_embeddings, tokens, prompt)
  185. # Summary
  186. print(f"\n=== SUMMARY ===")
  187. avg_cross_sim = np.mean(results['cross_model_similarities'])
  188. print(f"Average cross-model similarity: {avg_cross_sim:.4f}")
  189. print(f"Similarity matrix RMS difference: {results['rms_diff']:.4f}")
  190. # Quality assessment
  191. if avg_cross_sim > 0.95:
  192. print("✅ EXCELLENT: Models are highly similar")
  193. elif avg_cross_sim > 0.90:
  194. print("✅ VERY GOOD: Models are very similar")
  195. elif avg_cross_sim > 0.80:
  196. print("⚠️ GOOD: Models are reasonably similar")
  197. elif avg_cross_sim > 0.70:
  198. print("⚠️ FAIR: Models have some differences")
  199. else:
  200. exit_with_warning("❌ POOR: Models are significantly different", args.model_path)
  201. if __name__ == "__main__":
  202. main()