@@ -101,6 +101,17 @@ def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
         'rms_diff': np.sqrt(np.mean(diff_matrix**2))
     }
 
+def read_prompt_from_file(file_path):
+    try:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return f.read().strip()
+    except FileNotFoundError:
+        print(f"Error: Prompts file '{file_path}' not found")
+        exit(1)
+    except Exception as e:
+        print(f"Error reading prompts file: {e}")
+        exit(1)
+
 def main():
     parser = argparse.ArgumentParser(description='Test semantic similarity between Python and llama.cpp embeddings')
     parser.add_argument('--model-path', '-m', required=True, help='Path to the original Python model')
@@ -108,14 +119,20 @@ def main():
     parser.add_argument('--cpp-embeddings', '-ce', help='Path to llama.cpp embeddings "logits" binary file')
     parser.add_argument('--causal', '-c', default=False, help='if the model is causal (default: false)', action='store_true')
     parser.add_argument('--prompt', '-p', default='Hello world today', help='Test prompt')
+    parser.add_argument('--prompts-file', '-pf', help='Path to file containing prompts')
 
     args = parser.parse_args()
 
+    if args.prompts_file:
+        prompt = read_prompt_from_file(args.prompts_file)
+    else:
+        prompt = args.prompt
+
     print("Semantic Similarity Test Between Python and llama.cpp Embedding Models")
     print("=" * 70)
 
     # Single prompt detailed comparison
-    print(f"\nTesting with prompt: '{args.prompt}'")
+    print(f"\nTesting with prompt: '{prompt}'")
 
     # Load the python model to get configuration information and also to load the tokenizer.
     print("Loading model and tokenizer using AutoTokenizer:", args.model_path)
@@ -144,7 +161,7 @@ def main():
     else:
         model = AutoModel.from_pretrained(args.model_path)
 
-    encoded = tokenizer(args.prompt, return_tensors="pt")
+    encoded = tokenizer(prompt, return_tensors="pt")
     tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
     n_tokens = len(tokens)
     print(f"n_tokens: {n_tokens}");
@@ -155,7 +172,7 @@ def main():
     python_embeddings = load_embeddings_from_file(args.python_embeddings, n_tokens, model.config.hidden_size)
 
     # Run comparison
-    results = test_single_prompt_similarity(python_embeddings, llamacpp_embeddings, tokens, args.prompt)
+    results = test_single_prompt_similarity(python_embeddings, llamacpp_embeddings, tokens, prompt)
 
     # Summary
     print(f"\n=== SUMMARY ===")