semantic_check.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. #!/usr/bin/env python3
  2. import numpy as np
  3. import argparse
  4. import os
  5. import importlib
  6. from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel
  7. unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
  8. def cosine_similarity(a, b=None):
  9. a = np.asarray(a)
  10. if b is None:
  11. b = a
  12. else:
  13. b = np.asarray(b)
  14. if a.ndim == 1:
  15. a = a.reshape(1, -1)
  16. if b.ndim == 1:
  17. b = b.reshape(1, -1)
  18. a_norms = np.linalg.norm(a, axis=1, keepdims=True)
  19. b_norms = np.linalg.norm(b, axis=1, keepdims=True)
  20. a_norms = np.where(a_norms == 0, 1e-8, a_norms)
  21. b_norms = np.where(b_norms == 0, 1e-8, b_norms)
  22. a_normalized = a / a_norms
  23. b_normalized = b / b_norms
  24. # Compute cosine similarity
  25. return np.dot(a_normalized, b_normalized.T)
  26. def load_embeddings_from_file(filename, n_tokens, n_embd):
  27. embeddings = np.fromfile(filename, dtype=np.float32)
  28. # Check if this is pooled (single embedding) or per-token embeddings
  29. if len(embeddings) == n_embd:
  30. return embeddings.reshape(1, n_embd)
  31. else:
  32. return embeddings.reshape(n_tokens, n_embd)
  33. def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
  34. np.set_printoptions(suppress=True, precision=6)
  35. print("pytorch embeddings:");
  36. print(python_emb)
  37. print("llama.cpp embeddings:");
  38. print(cpp_emb)
  39. print(f"\n=== Prompt: '{prompt}' ===")
  40. print(f"Tokens: {tokens}")
  41. print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}")
  42. n_tokens = len(tokens)
  43. is_pooled = python_emb.shape[0] == 1
  44. if is_pooled:
  45. print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]")
  46. # 1. Direct embedding comparison for pooled embeddings
  47. print(f"\n1. Raw Embedding Magnitude Comparison:")
  48. py_mag = np.linalg.norm(python_emb[0])
  49. cpp_mag = np.linalg.norm(cpp_emb[0])
  50. ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
  51. print(f" Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
  52. # 2. Cross-model similarity for pooled embeddings
  53. print(f"\n2. Cross-Model Pooled Embedding Similarity:")
  54. sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0]
  55. print(f" Cosine similarity: {sim:.6f}")
  56. return {
  57. 'cross_model_similarities': [sim],
  58. 'similarity_matrix_diff': np.array([[0.0]]),
  59. 'max_diff': 0.0,
  60. 'mean_diff': 0.0,
  61. 'rms_diff': 0.0
  62. }
  63. else:
  64. # Original per-token comparison logic
  65. # 1. Direct embedding comparison
  66. print(f"\n1. Raw Embedding Magnitude Comparison:")
  67. # Check if the distance of each token embedding from the origin and compare
  68. # if the vectors are on the same "sphere". This does not tell us about
  69. # direction (meaning of the token embedding), just magnitude.
  70. for i in range(n_tokens):
  71. py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
  72. cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings
  73. ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
  74. print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
  75. # 2. Cosine similarity between tokens within each model
  76. # Here we check the direction of token embeddings to see if the have the
  77. # same meaning (similarity). This is done by calculating cosine similarity
  78. # of a pair of token embeddings within each model.
  79. print(f"\n2. Within-Model Token Similarities:")
  80. print(" Python model:")
  81. for i in range(n_tokens):
  82. for j in range(i+1, n_tokens):
  83. sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
  84. print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
  85. print(" llama.cpp model:")
  86. for i in range(n_tokens):
  87. for j in range(i+1, n_tokens):
  88. sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
  89. print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
  90. # 3. Cross-model similarity (same token position)
  91. print(f"\n3. Cross-Model Same-Token Similarities:")
  92. for i in range(n_tokens):
  93. sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
  94. print(f" Token {i} ({tokens[i]}): {sim:.4f}")
  95. # 4. Similarity matrix comparison
  96. print(f"\n4. Similarity Matrix Differences:")
  97. py_sim_matrix = cosine_similarity(python_emb)
  98. cpp_sim_matrix = cosine_similarity(cpp_emb)
  99. diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
  100. print(f" Max difference: {np.max(diff_matrix):.4f}")
  101. print(f" Mean difference: {np.mean(diff_matrix):.4f}")
  102. print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
  103. return {
  104. 'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
  105. 'similarity_matrix_diff': diff_matrix,
  106. 'max_diff': np.max(diff_matrix),
  107. 'mean_diff': np.mean(diff_matrix),
  108. 'rms_diff': np.sqrt(np.mean(diff_matrix**2))
  109. }
  110. def read_prompt_from_file(file_path):
  111. try:
  112. with open(file_path, 'r', encoding='utf-8') as f:
  113. return f.read().strip()
  114. except FileNotFoundError:
  115. print(f"Error: Prompts file '{file_path}' not found")
  116. exit(1)
  117. except Exception as e:
  118. print(f"Error reading prompts file: {e}")
  119. exit(1)
  120. def main():
  121. parser = argparse.ArgumentParser(description='Test semantic similarity between Python and llama.cpp embeddings')
  122. parser.add_argument('--model-path', '-m', required=True, help='Path to the original Python model')
  123. parser.add_argument('--python-embeddings', '-pe', help='Path to pytorch embeddings "logits" binary file')
  124. parser.add_argument('--cpp-embeddings', '-ce', help='Path to llama.cpp embeddings "logits" binary file')
  125. parser.add_argument('--causal', '-c', default=False, help='if the model is causal (default: false)', action='store_true')
  126. parser.add_argument('--prompt', '-p', default='Hello world today', help='Test prompt')
  127. parser.add_argument('--prompts-file', '-pf', help='Path to file containing prompts')
  128. args = parser.parse_args()
  129. if args.prompts_file:
  130. prompt = read_prompt_from_file(args.prompts_file)
  131. else:
  132. prompt = args.prompt
  133. print("Semantic Similarity Test Between Python and llama.cpp Embedding Models")
  134. print("=" * 70)
  135. # Single prompt detailed comparison
  136. print(f"\nTesting with prompt: '{prompt}'")
  137. # Load the python model to get configuration information and also to load the tokenizer.
  138. print("Loading model and tokenizer using AutoTokenizer:", args.model_path)
  139. tokenizer = AutoTokenizer.from_pretrained(args.model_path)
  140. config = AutoConfig.from_pretrained(args.model_path)
  141. if unreleased_model_name:
  142. model_name_lower = unreleased_model_name.lower()
  143. unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
  144. if args.causal:
  145. class_name = f"{unreleased_model_name}ForCausalLM"
  146. else:
  147. class_name = f"{unreleased_model_name}Model"
  148. print(f"Model class: {class_name}")
  149. print(f"Importing unreleased model module: {unreleased_module_path}")
  150. try:
  151. model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
  152. model = model_class.from_pretrained(args.model_path)
  153. except (ImportError, AttributeError) as e:
  154. print(f"Failed to import or load model: {e}")
  155. exit(1)
  156. else:
  157. if args.causal:
  158. model = AutoModelForCausalLM.from_pretrained(args.model_path)
  159. else:
  160. model = AutoModel.from_pretrained(args.model_path)
  161. encoded = tokenizer(prompt, return_tensors="pt")
  162. tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
  163. n_tokens = len(tokens)
  164. print(f"n_tokens: {n_tokens}");
  165. print(f"hidden_size: {model.config.hidden_size}")
  166. # Load binary embeddings from data directory.
  167. llamacpp_embeddings = load_embeddings_from_file(args.cpp_embeddings, n_tokens, model.config.hidden_size)
  168. python_embeddings = load_embeddings_from_file(args.python_embeddings, n_tokens, model.config.hidden_size)
  169. # Run comparison
  170. results = test_single_prompt_similarity(python_embeddings, llamacpp_embeddings, tokens, prompt)
  171. # Summary
  172. print(f"\n=== SUMMARY ===")
  173. avg_cross_sim = np.mean(results['cross_model_similarities'])
  174. print(f"Average cross-model similarity: {avg_cross_sim:.4f}")
  175. print(f"Similarity matrix RMS difference: {results['rms_diff']:.4f}")
  176. # Quality assessment
  177. if avg_cross_sim > 0.95:
  178. print("✅ EXCELLENT: Models are highly similar")
  179. elif avg_cross_sim > 0.90:
  180. print("✅ VERY GOOD: Models are very similar")
  181. elif avg_cross_sim > 0.80:
  182. print("⚠️ GOOD: Models are reasonably similar")
  183. elif avg_cross_sim > 0.70:
  184. print("⚠️ FAIR: Models have some differences")
  185. else:
  186. print("❌ POOR: Models are significantly different")
  187. if __name__ == "__main__":
  188. main()