compare-embeddings-logits.sh 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  #!/usr/bin/env bash
# Compare PyTorch reference embeddings with llama.cpp embeddings using
# scripts/utils/semantic_check.py.
#
# Usage:
#   compare-embeddings-logits.sh [MODEL_PATH] [MODEL_NAME] [--prompts-file|-pf FILE]
#
# MODEL_PATH falls back to $EMBEDDING_MODEL_PATH. When JSON embeddings are
# piped on stdin they are used instead of data/llamacpp-<MODEL_NAME>-embeddings.bin.
set -eo pipefail

#######################################
# Parse the command line into globals.
# Globals:   MODEL_PATH, MODEL_NAME, PROMPTS_FILE (written; reset to "" first)
# Arguments: the script's arguments
#   [MODEL_PATH]                 first argument, unless it is an option
#   [MODEL_NAME]                 first remaining positional argument
#   --prompts-file|-pf FILE      prompts file for semantic_check.py
# Returns:   exits 1 when an option is missing its value
#######################################
parse_args() {
  MODEL_PATH=""
  MODEL_NAME=""
  PROMPTS_FILE=""

  # The first argument is the model path unless it looks like an option.
  # Match a single leading dash so short options such as -pf are not
  # mistaken for a path.
  if [[ $# -gt 0 && "$1" != -* ]]; then
    MODEL_PATH="$1"
    shift
  fi

  while [[ $# -gt 0 ]]; do
    case "$1" in
      --prompts-file|-pf)
        if [[ $# -lt 2 ]]; then
          printf 'Error: %s requires a file argument\n' "$1" >&2
          exit 1
        fi
        PROMPTS_FILE="$2"
        shift 2
        ;;
      -*)
        # Unknown options are skipped (as before) but no longer silently.
        printf "Warning: ignoring unknown option '%s'\n" "$1" >&2
        shift
        ;;
      *)
        # First remaining positional is the model name; extras are ignored.
        if [[ -z "$MODEL_NAME" ]]; then
          MODEL_NAME="$1"
        fi
        shift
        ;;
    esac
  done
}

parse_args "$@"
  # Defaults: an empty model path falls back to $EMBEDDING_MODEL_PATH, and an
# empty model name is taken from the last component of the model path.
if [[ -z "$MODEL_PATH" ]]; then
  MODEL_PATH="$EMBEDDING_MODEL_PATH"
fi
if [[ -z "$MODEL_NAME" ]]; then
  MODEL_NAME="$(basename "$MODEL_PATH")"
fi
  # Pick the llama.cpp embeddings to compare: the file on disk when stdin is a
# terminal, otherwise JSON embeddings piped on stdin (converted to binary).

#######################################
# Convert JSON embeddings read from stdin into a flat binary file of C floats
# in native byte order (intended to match the logits.cpp fwrite format).
# Input: a JSON array of objects whose "embedding" value is a list of vectors
# (lists of numbers); all vectors are concatenated in order.
# Arguments: $1 - output file path
# Outputs:   "Total embedding values: N" on stderr
#######################################
json_embeddings_to_bin() {
  # The program is single-quoted so the shell expands nothing inside it, and
  # the output path is passed via argv instead of being spliced into the
  # source, so a path containing quotes cannot break or inject Python code.
  python3 -c '
import json
import sys
from array import array

data = json.load(sys.stdin)
values = array("f")
for item in data:
    for token_embedding in item["embedding"]:
        values.extend(token_embedding)
print(f"Total embedding values: {len(values)}", file=sys.stderr)
with open(sys.argv[1], "wb") as f:
    values.tofile(f)
' "$1"
}

if [[ -t 0 ]]; then
  CPP_EMBEDDINGS="data/llamacpp-${MODEL_NAME}-embeddings.bin"
else
  # Use a temp directory rather than a template with a suffix after the X's:
  # BSD/macOS mktemp does not substitute such templates.
  TEMP_DIR="$(mktemp -d "${TMPDIR:-/tmp}/compare-embeddings.XXXXXX")" || exit 1
  # Register cleanup before converting so a failed conversion (which aborts
  # the script under set -e) does not leave the directory behind.
  trap 'rm -rf -- "$TEMP_DIR"' EXIT
  CPP_EMBEDDINGS="$TEMP_DIR/llamacpp-embeddings.bin"
  json_embeddings_to_bin "$CPP_EMBEDDINGS"
fi
  # Run semantic_check.py on the PyTorch reference embeddings and the
# llama.cpp embeddings selected above.
if [[ -z "$MODEL_PATH" ]]; then
  printf '%s\n' "Error: no model path given; pass it as the first argument or set EMBEDDING_MODEL_PATH" >&2
  exit 1
fi
# Build the command as an array instead of a string for eval, so paths and
# prompts containing spaces, quotes or $ are passed through verbatim and never
# re-interpreted by the shell.
SEMANTIC_CMD=(
  python scripts/utils/semantic_check.py
  --model-path "$MODEL_PATH"
  --python-embeddings "data/pytorch-${MODEL_NAME}-embeddings.bin"
  --cpp-embeddings "$CPP_EMBEDDINGS"
)
# Use the prompts file if one was given, otherwise a default prompt.
if [[ -n "$PROMPTS_FILE" ]]; then
  SEMANTIC_CMD+=(--prompts-file "$PROMPTS_FILE")
else
  SEMANTIC_CMD+=(--prompt "Hello world today")
fi
# Run as a child (not exec) so any EXIT trap still fires afterwards.
"${SEMANTIC_CMD[@]}"