1
0

compare_tokens.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. #!/usr/bin/env python3
  2. import argparse
  3. import sys
  4. from common import compare_tokens # type: ignore
  5. def parse_arguments():
  6. parser = argparse.ArgumentParser(
  7. description='Compare tokens between two models',
  8. formatter_class=argparse.RawDescriptionHelpFormatter,
  9. epilog="""
  10. Examples:
  11. %(prog)s pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16
  12. """
  13. )
  14. parser.add_argument(
  15. 'original',
  16. help='Original model name'
  17. )
  18. parser.add_argument(
  19. 'converted',
  20. help='Converted model name'
  21. )
  22. parser.add_argument(
  23. '-s', '--suffix',
  24. default='',
  25. help='Type suffix (e.g., "-embeddings")'
  26. )
  27. parser.add_argument(
  28. '-d', '--data-dir',
  29. default='data',
  30. help='Directory containing token files (default: data)'
  31. )
  32. parser.add_argument(
  33. '-v', '--verbose',
  34. action='store_true',
  35. help='Print prompts from both models'
  36. )
  37. return parser.parse_args()
  38. def main():
  39. args = parse_arguments()
  40. if args.verbose:
  41. from pathlib import Path
  42. data_dir = Path(args.data_dir)
  43. prompt1_file = data_dir / f"{args.original}{args.suffix}-prompt.txt"
  44. prompt2_file = data_dir / f"{args.converted}{args.suffix}-prompt.txt"
  45. if prompt1_file.exists():
  46. print(f"\nOriginal model prompt ({args.original}):")
  47. print(f" {prompt1_file.read_text().strip()}")
  48. if prompt2_file.exists():
  49. print(f"\nConverted model prompt ({args.converted}):")
  50. print(f" {prompt2_file.read_text().strip()}")
  51. print()
  52. result = compare_tokens(
  53. args.original,
  54. args.converted,
  55. type_suffix=args.suffix,
  56. output_dir=args.data_dir
  57. )
  58. # Enable the script to be used in shell scripts so that they can check
  59. # the exit code for success/failure.
  60. sys.exit(0 if result else 1)
  61. if __name__ == "__main__":
  62. main()