compare-embeddings-logits.sh

#!/usr/bin/env bash
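# Usage (reconstructed from the argument handling below; adjust paths to your setup):
#
#   compare-embeddings-logits.sh [MODEL_PATH] [MODEL_NAME] [--prompts-file|-pf FILE]
#
# MODEL_PATH defaults to $EMBEDDING_MODEL_PATH and MODEL_NAME to its basename.
# The converted (GGUF) model is taken from $CONVERTED_EMBEDDING_PATH or
# $CONVERTED_EMBEDDING_MODEL. When data is piped in, it is treated as llama.cpp
# embedding output in JSON form; otherwise a pre-existing
# data/llamacpp-<converted-model>-embeddings.bin file is used.
#
# Illustrative invocations (the model and file paths here are placeholders):
#
#   ./compare-embeddings-logits.sh /models/my-embedding-model
#   cat cpp-embeddings.json | ./compare-embeddings-logits.sh /models/my-embedding-model -pf prompts.txt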
set -e

# Parse command line arguments
MODEL_PATH=""
MODEL_NAME=""
PROMPTS_FILE=""

# First argument is always model path
if [ $# -gt 0 ] && [[ "$1" != --* ]]; then
    MODEL_PATH="$1"
    shift
fi

# Parse remaining arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        --prompts-file|-pf)
            PROMPTS_FILE="$2"
            shift 2
            ;;
        *)
            # If MODEL_NAME not set and this isn't a flag, use as model name
            if [ -z "$MODEL_NAME" ] && [[ "$1" != --* ]]; then
                MODEL_NAME="$1"
            fi
            shift
            ;;
    esac
done
# Set defaults
MODEL_PATH="${MODEL_PATH:-"$EMBEDDING_MODEL_PATH"}"
MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}"
CONVERTED_MODEL_PATH="${CONVERTED_EMBEDDING_PATH:-"$CONVERTED_EMBEDDING_MODEL"}"
CONVERTED_MODEL_NAME="${CONVERTED_MODEL_NAME:-$(basename "$CONVERTED_MODEL_PATH" .gguf)}"
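# Illustrative example of the fallbacks above: with CONVERTED_EMBEDDING_PATH set to
# /models/mymodel-f16.gguf, CONVERTED_MODEL_NAME becomes "mymodel-f16" and the
# embeddings file below resolves to data/llamacpp-mymodel-f16-embeddings.bin.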
if [ -t 0 ]; then
    # No piped input (stdin is a terminal): use the llama.cpp embeddings file on disk
    CPP_EMBEDDINGS="data/llamacpp-${CONVERTED_MODEL_NAME}-embeddings.bin"
else
    # Process piped JSON data and convert to binary (matching logits.cpp format)
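    # Expected shape of the piped JSON (inferred from the parsing below): an array
    # of objects, each with an 'embedding' field holding one vector per token, e.g.
    #   [{"embedding": [[0.1, -0.2, ...], [0.3, 0.4, ...]]}, ...]
    # All values are flattened in order and written as raw 32-bit floats.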
    TEMP_FILE=$(mktemp /tmp/tmp.XXXXXX.bin)
    python3 -c "
import json
import sys
import struct

data = json.load(sys.stdin)

# Flatten all embeddings completely
flattened = []
for item in data:
    embedding = item['embedding']
    for token_embedding in embedding:
        flattened.extend(token_embedding)

print(f'Total embedding values: {len(flattened)}', file=sys.stderr)

# Write as binary floats - matches the logits.cpp fwrite format
with open('$TEMP_FILE', 'wb') as f:
    for value in flattened:
        f.write(struct.pack('f', value))
"
    CPP_EMBEDDINGS="$TEMP_FILE"
    trap "rm -f $TEMP_FILE" EXIT
fi
# Build the semantic_check.py command
SEMANTIC_CMD="python scripts/utils/semantic_check.py --model-path $MODEL_PATH \
    --python-embeddings data/pytorch-${MODEL_NAME}-embeddings.bin \
    --cpp-embeddings $CPP_EMBEDDINGS"

# Add prompts file if specified, otherwise use default prompt
if [ -n "$PROMPTS_FILE" ]; then
    SEMANTIC_CMD="$SEMANTIC_CMD --prompts-file \"$PROMPTS_FILE\""
else
    SEMANTIC_CMD="$SEMANTIC_CMD --prompt \"Hello world today\""
fi
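# For reference, the eval below expands to something like (paths here are illustrative):
#   python scripts/utils/semantic_check.py --model-path /models/my-embedding-model \
#       --python-embeddings data/pytorch-my-embedding-model-embeddings.bin \
#       --cpp-embeddings data/llamacpp-mymodel-f16-embeddings.bin \
#       --prompt "Hello world today"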
# Execute the command
eval $SEMANTIC_CMD