quantize.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. #!/usr/bin/env python3
  2. """Script to execute the "quantize" script on a given set of models."""
  3. import subprocess
  4. import argparse
  5. import glob
  6. import sys
  7. import os
  8. def main():
  9. """Update the quantize binary name depending on the platform and parse
  10. the command line arguments and execute the script.
  11. """
  12. if "linux" in sys.platform or "darwin" in sys.platform:
  13. quantize_script_binary = "quantize"
  14. elif "win32" in sys.platform or "cygwin" in sys.platform:
  15. quantize_script_binary = "quantize.exe"
  16. else:
  17. print("WARNING: Unknown platform. Assuming a UNIX-like OS.\n")
  18. quantize_script_binary = "quantize"
  19. parser = argparse.ArgumentParser(
  20. prog='python3 quantize.py',
  21. description='This script quantizes the given models by applying the '
  22. f'"{quantize_script_binary}" script on them.'
  23. )
  24. parser.add_argument(
  25. 'models', nargs='+', choices=('7B', '13B', '30B', '65B'),
  26. help='The models to quantize.'
  27. )
  28. parser.add_argument(
  29. '-r', '--remove-16', action='store_true', dest='remove_f16',
  30. help='Remove the f16 model after quantizing it.'
  31. )
  32. parser.add_argument(
  33. '-m', '--models-path', dest='models_path',
  34. default=os.path.join(os.getcwd(), "models"),
  35. help='Specify the directory where the models are located.'
  36. )
  37. parser.add_argument(
  38. '-q', '--quantize-script-path', dest='quantize_script_path',
  39. default=os.path.join(os.getcwd(), quantize_script_binary),
  40. help='Specify the path to the "quantize" script.'
  41. )
  42. # TODO: Revise this code
  43. # parser.add_argument(
  44. # '-t', '--threads', dest='threads', type='int',
  45. # default=os.cpu_count(),
  46. # help='Specify the number of threads to use to quantize many models at '
  47. # 'once. Defaults to os.cpu_count().'
  48. # )
  49. args = parser.parse_args()
  50. if not os.path.isfile(args.quantize_script_path):
  51. print(
  52. f'The "{quantize_script_binary}" script was not found in the '
  53. "current location.\nIf you want to use it from another location, "
  54. "set the --quantize-script-path argument from the command line."
  55. )
  56. sys.exit(1)
  57. for model in args.models:
  58. # The model is separated in various parts
  59. # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
  60. f16_model_path_base = os.path.join(
  61. args.models_path, model, "ggml-model-f16.bin"
  62. )
  63. f16_model_parts_paths = map(
  64. lambda filename: os.path.join(f16_model_path_base, filename),
  65. glob.glob(f"{f16_model_path_base}*")
  66. )
  67. for f16_model_part_path in f16_model_parts_paths:
  68. if not os.path.isfile(f16_model_part_path):
  69. print(
  70. f"The f16 model {os.path.basename(f16_model_part_path)} "
  71. f"was not found in {args.models_path}{os.path.sep}{model}"
  72. ". If you want to use it from another location, set the "
  73. "--models-path argument from the command line."
  74. )
  75. sys.exit(1)
  76. __run_quantize_script(
  77. args.quantize_script_path, f16_model_part_path
  78. )
  79. if args.remove_f16:
  80. os.remove(f16_model_part_path)
  81. # This was extracted to a top-level function for parallelization, if
  82. # implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
  83. def __run_quantize_script(script_path, f16_model_part_path):
  84. """Run the quantize script specifying the path to it and the path to the
  85. f16 model to quantize.
  86. """
  87. new_quantized_model_path = f16_model_part_path.replace("f16", "q4_0")
  88. subprocess.run(
  89. [script_path, f16_model_part_path, new_quantized_model_path, "2"],
  90. check=True
  91. )
  92. if __name__ == "__main__":
  93. try:
  94. main()
  95. except subprocess.CalledProcessError:
  96. print("\nAn error ocurred while trying to quantize the models.")
  97. sys.exit(1)
  98. except KeyboardInterrupt:
  99. sys.exit(0)
  100. else:
  101. print("\nSuccesfully quantized all models.")