#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# type: ignore[reportUnusedImport]
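"""
Shared helpers for the llama-server test suite: spawning and supervising a
llama-server process, issuing plain and streaming HTTP requests against it, and
a few small utilities used by the tests.
"""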
import subprocess
import os
import re
import json
import sys
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
    Any,
    Callable,
    ContextManager,
    Iterable,
    Iterator,
    List,
    Literal,
    Tuple,
    Set,
)
from re import RegexFlag
import wget

DEFAULT_HTTP_TIMEOUT = 30


class ServerResponse:
    headers: dict
    status_code: int
    body: dict | Any


class ServerProcess:
    # default options
    debug: bool = False
    server_port: int = 8080
    server_host: str = "127.0.0.1"
    model_hf_repo: str = "ggml-org/models"
    model_hf_file: str | None = "tinyllamas/stories260K.gguf"
    model_alias: str = "tinyllama-2"
    temperature: float = 0.8
    seed: int = 42

    # custom options
    model_url: str | None = None
    model_file: str | None = None
    model_draft: str | None = None
    n_threads: int | None = None
    n_gpu_layer: int | None = None
    n_batch: int | None = None
    n_ubatch: int | None = None
    n_ctx: int | None = None
    n_ga: int | None = None
    n_ga_w: int | None = None
    n_predict: int | None = None
    n_prompts: int | None = 0
    slot_save_path: str | None = None
    id_slot: int | None = None
    cache_prompt: bool | None = None
    n_slots: int | None = None
    ctk: str | None = None
    ctv: str | None = None
    fa: bool | None = None
    server_continuous_batching: bool | None = False
    server_embeddings: bool | None = False
    server_reranking: bool | None = False
    server_metrics: bool | None = False
    server_slots: bool | None = False
    pooling: str | None = None
    draft: int | None = None
    api_key: str | None = None
    lora_files: List[str] | None = None
    enable_ctx_shift: bool | None = False
    draft_min: int | None = None
    draft_max: int | None = None
    no_webui: bool | None = None
    jinja: bool | None = None
    reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
    reasoning_budget: int | None = None
    chat_template: str | None = None
    chat_template_file: str | None = None
    server_path: str | None = None
    mmproj_url: str | None = None

    # session variables
    process: subprocess.Popen | None = None

    def __init__(self):
        if "N_GPU_LAYERS" in os.environ:
            self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
        if "DEBUG" in os.environ:
            self.debug = True
        if "PORT" in os.environ:
            self.server_port = int(os.environ["PORT"])

    def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
        if self.server_path is not None:
            server_path = self.server_path
        elif "LLAMA_SERVER_BIN_PATH" in os.environ:
            server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
        elif os.name == "nt":
            server_path = "../../../build/bin/Release/llama-server.exe"
        else:
            server_path = "../../../build/bin/llama-server"
        server_args = [
            "--host",
            self.server_host,
            "--port",
            self.server_port,
            "--temp",
            self.temperature,
            "--seed",
            self.seed,
        ]
        if self.model_file:
            server_args.extend(["--model", self.model_file])
        if self.model_url:
            server_args.extend(["--model-url", self.model_url])
        if self.model_draft:
            server_args.extend(["--model-draft", self.model_draft])
        if self.model_hf_repo:
            server_args.extend(["--hf-repo", self.model_hf_repo])
        if self.model_hf_file:
            server_args.extend(["--hf-file", self.model_hf_file])
        if self.n_batch:
            server_args.extend(["--batch-size", self.n_batch])
        if self.n_ubatch:
            server_args.extend(["--ubatch-size", self.n_ubatch])
        if self.n_threads:
            server_args.extend(["--threads", self.n_threads])
        if self.n_gpu_layer:
            server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
        if self.draft is not None:
            server_args.extend(["--draft", self.draft])
        if self.server_continuous_batching:
            server_args.append("--cont-batching")
        if self.server_embeddings:
            server_args.append("--embedding")
        if self.server_reranking:
            server_args.append("--reranking")
        if self.server_metrics:
            server_args.append("--metrics")
        if self.server_slots:
            server_args.append("--slots")
        if self.pooling:
            server_args.extend(["--pooling", self.pooling])
        if self.model_alias:
            server_args.extend(["--alias", self.model_alias])
        if self.n_ctx:
            server_args.extend(["--ctx-size", self.n_ctx])
        if self.n_slots:
            server_args.extend(["--parallel", self.n_slots])
        if self.ctk:
            server_args.extend(["-ctk", self.ctk])
        if self.ctv:
            server_args.extend(["-ctv", self.ctv])
        if self.fa:
            server_args.append("-fa")
        if self.n_predict:
            server_args.extend(["--n-predict", self.n_predict])
        if self.slot_save_path:
            server_args.extend(["--slot-save-path", self.slot_save_path])
        if self.n_ga:
            server_args.extend(["--grp-attn-n", self.n_ga])
        if self.n_ga_w:
            server_args.extend(["--grp-attn-w", self.n_ga_w])
        if self.debug:
            server_args.append("--verbose")
        if self.lora_files:
            for lora_file in self.lora_files:
                server_args.extend(["--lora", lora_file])
        if self.enable_ctx_shift:
            server_args.append("--context-shift")
        if self.api_key:
            server_args.extend(["--api-key", self.api_key])
        if self.draft_max:
            server_args.extend(["--draft-max", self.draft_max])
        if self.draft_min:
            server_args.extend(["--draft-min", self.draft_min])
        if self.no_webui:
            server_args.append("--no-webui")
        if self.jinja:
            server_args.append("--jinja")
        if self.reasoning_format is not None:
            server_args.extend(["--reasoning-format", self.reasoning_format])
        if self.reasoning_budget is not None:
            server_args.extend(["--reasoning-budget", self.reasoning_budget])
        if self.chat_template:
            server_args.extend(["--chat-template", self.chat_template])
        if self.chat_template_file:
            server_args.extend(["--chat-template-file", self.chat_template_file])
        if self.mmproj_url:
            server_args.extend(["--mmproj-url", self.mmproj_url])
        args = [str(arg) for arg in [server_path, *server_args]]
        print(f"tests: starting server with: {' '.join(args)}")
        flags = 0
        if os.name == "nt":
            flags |= subprocess.DETACHED_PROCESS
            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
            flags |= subprocess.CREATE_NO_WINDOW
        self.process = subprocess.Popen(
            args,
            creationflags=flags,
            stdout=sys.stdout,
            stderr=sys.stdout,
            env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
        )
        server_instances.add(self)
        print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")

        # wait for the server to start
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
                response = self.make_request("GET", "/health", headers={
                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                })
                if response.status_code == 200:
                    self.ready = True
                    return  # server is ready
            except Exception:
                pass
            # check whether the process died while we were waiting
            if self.process.poll() is not None:
                raise RuntimeError(f"Server process died with return code {self.process.returncode}")
            print("Waiting for server to start...")
            time.sleep(0.5)
        raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")

    def stop(self) -> None:
        if self in server_instances:
            server_instances.remove(self)
        if self.process:
            print(f"Stopping server with pid={self.process.pid}")
            self.process.kill()
            self.process = None

    def make_request(
        self,
        method: str,
        path: str,
        data: dict | Any | None = None,
        headers: dict | None = None,
        timeout: float | None = None,
    ) -> ServerResponse:
        url = f"http://{self.server_host}:{self.server_port}{path}"
        parse_body = False
        if method == "GET":
            response = requests.get(url, headers=headers, timeout=timeout)
            parse_body = True
        elif method == "POST":
            response = requests.post(url, headers=headers, json=data, timeout=timeout)
            parse_body = True
        elif method == "OPTIONS":
            response = requests.options(url, headers=headers, timeout=timeout)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        result = ServerResponse()
        result.headers = dict(response.headers)
        result.status_code = response.status_code
        result.body = response.json() if parse_body else None
        print("Response from server", json.dumps(result.body, indent=2))
        return result

    def make_stream_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        headers: dict | None = None,
    ) -> Iterator[dict]:
        url = f"http://{self.server_host}:{self.server_port}{path}"
        if method == "POST":
            response = requests.post(url, headers=headers, json=data, stream=True)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        for line_bytes in response.iter_lines():
            line = line_bytes.decode("utf-8")
            if '[DONE]' in line:
                break
            elif line.startswith('data: '):
                data = json.loads(line[6:])
                print("Partial response from server", json.dumps(data, indent=2))
                yield data
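
    # Note: the stream endpoint replies with server-sent events: each payload
    # line looks like `data: {...}`, and the stream ends with a `data: [DONE]`
    # line, which the loop above treats as the end of the stream.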

    def make_any_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        headers: dict | None = None,
        timeout: float | None = None,
    ) -> dict:
        stream = data.get('stream', False) if data else False
        if stream:
            content: list[str] = []
            reasoning_content: list[str] = []
            tool_calls: list[dict] = []
            finish_reason: str | None = None
            content_parts = 0
            reasoning_content_parts = 0
            tool_call_parts = 0
            arguments_parts = 0
            for chunk in self.make_stream_request(method, path, data, headers):
                if chunk['choices']:
                    assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
                    choice = chunk['choices'][0]
                    if choice['delta'].get('content') is not None:
                        assert len(choice['delta']['content']) > 0, 'Expected non-empty content delta!'
                        content.append(choice['delta']['content'])
                        content_parts += 1
                    if choice['delta'].get('reasoning_content') is not None:
                        assert len(choice['delta']['reasoning_content']) > 0, 'Expected non-empty reasoning_content delta!'
                        reasoning_content.append(choice['delta']['reasoning_content'])
                        reasoning_content_parts += 1
                    if choice['delta'].get('finish_reason') is not None:
                        finish_reason = choice['delta']['finish_reason']
                    for tc in choice['delta'].get('tool_calls', []):
                        if 'function' not in tc:
                            raise ValueError(f"Expected function type, got {tc.get('type')}")
                        if tc['index'] >= len(tool_calls):
                            assert 'id' in tc
                            assert tc.get('type') == 'function'
                            assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
                                f"Expected function call with name, got {tc.get('function')}"
                            tool_calls.append(dict(
                                id="",
                                type="function",
                                function=dict(
                                    name="",
                                    arguments="",
                                )
                            ))
                        tool_call = tool_calls[tc['index']]
                        if tc.get('id') is not None:
                            tool_call['id'] = tc['id']
                        fct = tc['function']
                        assert 'id' not in fct, f"Function call should not have id: {fct}"
                        if fct.get('name') is not None:
                            tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
                        if fct.get('arguments') is not None:
                            tool_call['function']['arguments'] += fct['arguments']
                            arguments_parts += 1
                        tool_call_parts += 1
                else:
                    # When `include_usage` is True (the default), we expect the last chunk of the stream
                    # immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
                    # and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
                    # the last chunk)
                    assert 'usage' in chunk, f"Expected usage in chunk: {chunk}"
                    assert 'timings' in chunk, f"Expected timings in chunk: {chunk}"
            print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
            result = dict(
                choices=[
                    dict(
                        index=0,
                        finish_reason=finish_reason,
                        message=dict(
                            role='assistant',
                            content=''.join(content) if content else None,
                            reasoning_content=''.join(reasoning_content) if reasoning_content else None,
                            tool_calls=tool_calls if tool_calls else None,
                        ),
                    )
                ],
            )
            print("Final response from server", json.dumps(result, indent=2))
            return result
        else:
            response = self.make_request(method, path, data, headers, timeout=timeout)
            assert response.status_code == 200, f"Server returned error: {response.status_code}"
            return response.body


server_instances: Set[ServerProcess] = set()
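
# Every server registers itself here in start() and deregisters in stop(), so a
# test harness can clean up any leftover processes, e.g. (illustrative sketch):
#
#     for server in list(server_instances):
#         server.stop()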


class ServerPreset:
    @staticmethod
    def tinyllama2() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "tinyllamas/stories260K.gguf"
        server.model_alias = "tinyllama-2"
        server.n_ctx = 512
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 64
        server.seed = 42
        return server

    @staticmethod
    def bert_bge_small() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 512
        server.n_batch = 128
        server.n_ubatch = 128
        server.n_slots = 2
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def bert_bge_small_with_fa() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 1024
        server.n_batch = 300
        server.n_ubatch = 300
        server.n_slots = 2
        server.fa = True
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def tinyllama_infill() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
        server.model_alias = "tinyllama-infill"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def stories15m_moe() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/stories15M_MOE"
        server.model_hf_file = "stories15M_MOE-F16.gguf"
        server.model_alias = "stories15m-moe"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def jina_reranker_tiny() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
        server.model_alias = "jina-reranker"
        server.n_ctx = 512
        server.n_batch = 512
        server.n_slots = 1
        server.seed = 42
        server.server_reranking = True
        return server

    @staticmethod
    def tinygemma3() -> ServerProcess:
        server = ServerProcess()
        # mmproj is already provided by HF registry API
        server.model_hf_repo = "ggml-org/tinygemma3-GGUF"
        server.model_hf_file = "tinygemma3-Q8_0.gguf"
        server.mmproj_url = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/mmproj-tinygemma3.gguf"
        server.model_alias = "tinygemma3"
        server.n_ctx = 1024
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 4
        server.seed = 42
        return server


def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
    """
    Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.

    Example usage:

    results = parallel_function_calls([
        (func1, (arg1, arg2)),
        (func2, (arg3, arg4)),
    ])
    """
    results = [None] * len(function_list)
    exceptions = []

    def worker(index, func, args):
        try:
            result = func(*args)
            results[index] = result
        except Exception as e:
            exceptions.append((index, str(e)))

    with ThreadPoolExecutor() as executor:
        futures = []
        for i, (func, args) in enumerate(function_list):
            future = executor.submit(worker, i, func, args)
            futures.append(future)
        # Wait for all futures to complete
        for future in as_completed(futures):
            pass

    # Report any exceptions; the corresponding slots in `results` stay None
    if exceptions:
        print("Exceptions occurred:")
        for index, error in exceptions:
            print(f"Function at index {index}: {error}")
    return results


def match_regex(regex: str, text: str) -> bool:
    return (
        re.compile(
            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
        ).search(text)
        is not None
    )
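
# For example, match_regex(r"hello.world", "HELLO\nWORLD") is True: matching is
# case-insensitive and DOTALL lets '.' match the newline.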


def download_file(url: str, output_file_path: str | None = None) -> str:
    """
    Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.

    output_file_path is the local path to save the downloaded file. If not provided, the file will be saved
    under ./tmp using the file name from the URL.

    Returns the local path of the downloaded file.
    """
    file_name = url.split('/').pop()
    output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
    if not os.path.exists(output_file):
        print(f"Downloading {url} to {output_file}")
        wget.download(url, out=output_file)
        print(f"Done downloading to {output_file}")
    else:
        print(f"File already exists at {output_file}")
    return output_file
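
# For example, download_file("https://example.com/model.gguf") stores the file
# at ./tmp/model.gguf and returns that path (the URL here is illustrative, and
# the ./tmp directory must already exist).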


def is_slow_test_allowed() -> bool:
    return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
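

# Minimal smoke-test sketch, illustrative only: it assumes a llama-server binary
# is reachable via the default build path or LLAMA_SERVER_BIN_PATH, and that the
# tiny test model can be downloaded from Hugging Face.
if __name__ == "__main__":
    server = ServerPreset.tinyllama2()
    try:
        server.start()
        res = server.make_request("GET", "/health")
        print("health status:", res.status_code)
    finally:
        server.stop()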