| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- #!/usr/bin/env python
- '''
- This script fetches all the models used in the server tests.
- This is useful for slow tests that use larger models, to avoid them timing out on the model downloads.
- It is meant to be run from the root of the repository.
- Example:
- python scripts/fetch_server_test_models.py
- ( cd tools/server/tests && ./tests.sh -v -x -m slow )
- '''
- import ast
- import glob
- import logging
- import os
- from typing import Generator
- from pydantic import BaseModel
- from typing import Optional
- import subprocess
class HuggingFaceModel(BaseModel):
    """Reference to a model file hosted in a Hugging Face repository.

    The model is configured as frozen (see the pydantic ``frozen`` config
    option) so that instances can be de-duplicated by collecting them in a
    ``set``.
    """

    # Repository identifier, passed to llama-cli via `-hfr`.
    hf_repo: str
    # Optional file within the repository, passed via `-hff` when set.
    hf_file: Optional[str] = None

    class Config:
        frozen = True
def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
    """Yield the Hugging Face models named in ``parametrize`` decorators of a test file.

    The file is parsed with ``ast``; it is never imported or executed. Every
    (sync or async) function decorated with a ``<something>.parametrize(...)``
    call whose argument names include ``hf_repo`` contributes one model per
    value tuple. ``hf_file`` is read from the same tuple when that name is
    also present, otherwise it is ``None``.

    Args:
        test_file: Path of the Python test file to scan.

    Yields:
        One ``HuggingFaceModel`` per parametrized tuple. Duplicates are not
        removed here.

    Files that cannot be read or parsed are logged as errors and yield
    nothing. Individual entries that cannot be evaluated statically are
    logged as warnings and skipped, so one odd entry does not abort the scan.
    """
    try:
        # Read bytes so that ast.parse honours the source file's own encoding
        # (UTF-8 or a PEP 263 declaration) rather than the locale default.
        with open(test_file, 'rb') as f:
            tree = ast.parse(f.read())
    except Exception as e:
        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
        return

    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        for dec in node.decorator_list:
            is_parametrize = (
                isinstance(dec, ast.Call)
                and isinstance(dec.func, ast.Attribute)
                and dec.func.attr == 'parametrize'
            )
            if not is_parametrize:
                continue

            # argnames / argvalues may be passed positionally or by keyword.
            keywords = {kw.arg: kw.value for kw in dec.keywords}
            names_node = dec.args[0] if len(dec.args) > 0 else keywords.get('argnames')
            raw_param_values = dec.args[1] if len(dec.args) > 1 else keywords.get('argvalues')
            if names_node is None or raw_param_values is None:
                continue

            try:
                raw_names = ast.literal_eval(names_node)
            except (ValueError, TypeError, SyntaxError):
                logging.warning(f'Skipping parametrize with non-literal names at {test_file}:{node.lineno}')
                continue

            if isinstance(raw_names, str):
                # Comma-separated names; surrounding whitespace is not part
                # of the name ("hf_repo, hf_file" names two parameters).
                param_names = [name.strip() for name in raw_names.split(',') if name.strip()]
            elif isinstance(raw_names, (list, tuple)):
                param_names = list(raw_names)
            else:
                continue

            if 'hf_repo' not in param_names:
                continue

            if not isinstance(raw_param_values, ast.List):
                logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
                continue

            hf_repo_idx = param_names.index('hf_repo')
            hf_file_idx = param_names.index('hf_file') if 'hf_file' in param_names else None

            for t in raw_param_values.elts:
                if not isinstance(t, ast.Tuple):
                    logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
                    continue
                try:
                    hf_repo = ast.literal_eval(t.elts[hf_repo_idx])
                    hf_file = ast.literal_eval(t.elts[hf_file_idx]) if hf_file_idx is not None else None
                except (IndexError, ValueError, TypeError, SyntaxError) as e:
                    # Tuple shorter than the name list, or a value that is
                    # not a plain literal (e.g. a variable reference).
                    logging.warning(f'Skipping unparsable parametrize entry at {test_file}:{node.lineno}: {e}')
                    continue
                yield HuggingFaceModel(hf_repo=hf_repo, hf_file=hf_file)
if __name__ == '__main__':
    # Script entry point: collect every model referenced by the server unit
    # tests and run llama-cli once per model so that it gets downloaded.
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    # De-duplicate through a set, then sort for stable output. hf_file may be
    # None, which cannot be ordered against str, so it is mapped to '' in the
    # sort key (otherwise two entries for the same repo, one without a file,
    # raise TypeError).
    models = sorted(
        {
            model
            for test_file in glob.glob('tools/server/tests/unit/test_*.py')
            for model in collect_hf_model_test_parameters(test_file)
        },
        key=lambda m: (m.hf_repo, m.hf_file or ''))

    logging.info(f'Found {len(models)} models in parameterized tests:')
    for m in models:
        logging.info(f' - {m.hf_repo} / {m.hf_file}')

    # The binary location can be overridden with LLAMA_CLI_BIN_PATH; the
    # default is resolved relative to this script's directory.
    cli_path = os.environ.get(
        'LLAMA_CLI_BIN_PATH',
        os.path.join(
            os.path.dirname(__file__),
            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))

    for m in models:
        # Entries containing '<' are skipped silently.
        # NOTE(review): presumably these are placeholders, not real names - confirm.
        if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file):
            continue
        if m.hf_file is not None and '-of-' in m.hf_file:
            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
            continue

        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
        cmd = [
            cli_path,
            '-hfr', m.hf_repo,
            *([] if m.hf_file is None else ['-hff', m.hf_file]),
            '-n', '1',
            '-p', 'Hey',
            '--no-warmup',
            '--log-disable',
            '-no-cnv']
        # NOTE(review): `-fa` is deliberately omitted for these two models;
        # presumably they do not work with it - confirm.
        if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
            cmd.append('-fa')

        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError:
            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}')
            # SystemExit directly: the `exit` builtin is only injected by the
            # `site` module and is not guaranteed to exist.
            raise SystemExit(1)
        except OSError as e:
            # The binary is missing or not executable: report it instead of
            # dying with a bare traceback.
            logging.error(f'Could not run {cli_path}: {e}')
            raise SystemExit(1)
|