fetch_server_test_models.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #!/usr/bin/env python
  2. '''
  3. This script fetches all the models used in the server tests.
  4. This is useful for slow tests that use larger models, to avoid them timing out on the model downloads.
  5. It is meant to be run from the root of the repository.
  6. Example:
  7. python scripts/fetch_server_test_models.py
  8. ( cd examples/server/tests && ./tests.sh -v -x -m slow )
  9. '''
  10. import ast
  11. import glob
  12. import logging
  13. import os
  14. from typing import Generator
  15. from pydantic import BaseModel
  16. from typing import Optional
  17. import subprocess
  18. class HuggingFaceModel(BaseModel):
  19. hf_repo: str
  20. hf_file: Optional[str] = None
  21. class Config:
  22. frozen = True
  23. def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
  24. try:
  25. with open(test_file) as f:
  26. tree = ast.parse(f.read())
  27. except Exception as e:
  28. logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
  29. return
  30. for node in ast.walk(tree):
  31. if isinstance(node, ast.FunctionDef):
  32. for dec in node.decorator_list:
  33. if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize':
  34. param_names = ast.literal_eval(dec.args[0]).split(",")
  35. if "hf_repo" not in param_names:
  36. continue
  37. raw_param_values = dec.args[1]
  38. if not isinstance(raw_param_values, ast.List):
  39. logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
  40. continue
  41. hf_repo_idx = param_names.index("hf_repo")
  42. hf_file_idx = param_names.index("hf_file") if "hf_file" in param_names else None
  43. for t in raw_param_values.elts:
  44. if not isinstance(t, ast.Tuple):
  45. logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
  46. continue
  47. yield HuggingFaceModel(
  48. hf_repo=ast.literal_eval(t.elts[hf_repo_idx]),
  49. hf_file=ast.literal_eval(t.elts[hf_file_idx]) if hf_file_idx is not None else None)
  50. if __name__ == '__main__':
  51. logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
  52. models = sorted(list(set([
  53. model
  54. for test_file in glob.glob('examples/server/tests/unit/test_*.py')
  55. for model in collect_hf_model_test_parameters(test_file)
  56. ])), key=lambda m: (m.hf_repo, m.hf_file))
  57. logging.info(f'Found {len(models)} models in parameterized tests:')
  58. for m in models:
  59. logging.info(f' - {m.hf_repo} / {m.hf_file}')
  60. cli_path = os.environ.get(
  61. 'LLAMA_CLI_BIN_PATH',
  62. os.path.join(
  63. os.path.dirname(__file__),
  64. '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))
  65. for m in models:
  66. if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file):
  67. continue
  68. if m.hf_file is not None and '-of-' in m.hf_file:
  69. logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
  70. continue
  71. logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
  72. cmd = [
  73. cli_path,
  74. '-hfr', m.hf_repo,
  75. *([] if m.hf_file is None else ['-hff', m.hf_file]),
  76. '-n', '1',
  77. '-p', 'Hey',
  78. '--no-warmup',
  79. '--log-disable',
  80. '-no-cnv']
  81. if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
  82. cmd.append('-fa')
  83. try:
  84. subprocess.check_call(cmd)
  85. except subprocess.CalledProcessError:
  86. logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}')
  87. exit(1)