#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# type: ignore[reportUnusedImport]
import subprocess
import os
import re
import json
from json import JSONDecodeError
import sys
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
    Any,
    Callable,
    ContextManager,
    Iterable,
    Iterator,
    List,
    Literal,
    Tuple,
    Set,
)
from re import RegexFlag
import wget

DEFAULT_HTTP_TIMEOUT = 60


class ServerResponse:
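    """ Lightweight container for an HTTP response: headers, status code and parsed body. """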
    headers: dict
    status_code: int
    body: dict | Any


class ServerError(Exception):
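    """ Raised when the server answers a streaming request with a non-200 status code. """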
    def __init__(self, code, body):
        self.code = code
        self.body = body


class ServerProcess:
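    """
    Wraps a llama-server child process for tests: the attributes below are translated
    into CLI arguments by start(), and the make_*_request helpers talk to the running
    server over HTTP. N_GPU_LAYERS, DEBUG, PORT and DEBUG_EXTERNAL environment
    variables override the defaults in __init__.
    """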
    # default options
    debug: bool = False
    server_port: int = 8080
    server_host: str = "127.0.0.1"
    model_hf_repo: str | None = "ggml-org/models"
    model_hf_file: str | None = "tinyllamas/stories260K.gguf"
    temperature: float = 0.8
    seed: int = 42
    offline: bool = False

    # custom options
    model_alias: str | None = None
    model_url: str | None = None
    model_file: str | None = None
    model_draft: str | None = None
    n_threads: int | None = None
    n_gpu_layer: int | None = None
    n_batch: int | None = None
    n_ubatch: int | None = None
    n_ctx: int | None = None
    n_ga: int | None = None
    n_ga_w: int | None = None
    n_predict: int | None = None
    n_prompts: int | None = 0
    slot_save_path: str | None = None
    id_slot: int | None = None
    cache_prompt: bool | None = None
    n_slots: int | None = None
    ctk: str | None = None
    ctv: str | None = None
    fa: str | None = None
    server_continuous_batching: bool | None = False
    server_embeddings: bool | None = False
    server_reranking: bool | None = False
    server_metrics: bool | None = False
    kv_unified: bool | None = False
    server_slots: bool | None = False
    pooling: str | None = None
    draft: int | None = None
    api_key: str | None = None
    models_dir: str | None = None
    models_max: int | None = None
    no_models_autoload: bool | None = None
    lora_files: List[str] | None = None
    enable_ctx_shift: bool | None = False
    draft_min: int | None = None
    draft_max: int | None = None
    no_webui: bool | None = None
    jinja: bool | None = None
    reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
    reasoning_budget: int | None = None
    chat_template: str | None = None
    chat_template_file: str | None = None
    server_path: str | None = None
    mmproj_url: str | None = None
    media_path: str | None = None
    sleep_idle_seconds: int | None = None

    # session variables
    process: subprocess.Popen | None = None

    def __init__(self):
        if "N_GPU_LAYERS" in os.environ:
            self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
        if "DEBUG" in os.environ:
            self.debug = True
        if "PORT" in os.environ:
            self.server_port = int(os.environ["PORT"])
        self.external_server = "DEBUG_EXTERNAL" in os.environ

    def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
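        """
        Launch the server binary with arguments built from the attributes above, then
        poll /health until it responds. Raises RuntimeError if the process dies and
        TimeoutError if it does not come up within timeout_seconds. No-op when an
        external server is used (DEBUG_EXTERNAL is set).
        """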
        if self.external_server:
            print(f"[external_server]: Assuming external server running on {self.server_host}:{self.server_port}")
            return
        if self.server_path is not None:
            server_path = self.server_path
        elif "LLAMA_SERVER_BIN_PATH" in os.environ:
            server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
        elif os.name == "nt":
            server_path = "../../../build/bin/Release/llama-server.exe"
        else:
            server_path = "../../../build/bin/llama-server"
        server_args = [
            "--host",
            self.server_host,
            "--port",
            self.server_port,
            "--temp",
            self.temperature,
            "--seed",
            self.seed,
        ]
        if self.offline:
            server_args.append("--offline")
        if self.model_file:
            server_args.extend(["--model", self.model_file])
        if self.model_url:
            server_args.extend(["--model-url", self.model_url])
        if self.model_draft:
            server_args.extend(["--model-draft", self.model_draft])
        if self.model_hf_repo:
            server_args.extend(["--hf-repo", self.model_hf_repo])
        if self.model_hf_file:
            server_args.extend(["--hf-file", self.model_hf_file])
        if self.models_dir:
            server_args.extend(["--models-dir", self.models_dir])
        if self.models_max is not None:
            server_args.extend(["--models-max", self.models_max])
        if self.n_batch:
            server_args.extend(["--batch-size", self.n_batch])
        if self.n_ubatch:
            server_args.extend(["--ubatch-size", self.n_ubatch])
        if self.n_threads:
            server_args.extend(["--threads", self.n_threads])
        if self.n_gpu_layer:
            server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
        if self.draft is not None:
            server_args.extend(["--draft", self.draft])
        if self.server_continuous_batching:
            server_args.append("--cont-batching")
        if self.server_embeddings:
            server_args.append("--embedding")
        if self.server_reranking:
            server_args.append("--reranking")
        if self.server_metrics:
            server_args.append("--metrics")
        if self.kv_unified:
            server_args.append("--kv-unified")
        if self.server_slots:
            server_args.append("--slots")
        else:
            server_args.append("--no-slots")
        if self.pooling:
            server_args.extend(["--pooling", self.pooling])
        if self.model_alias:
            server_args.extend(["--alias", self.model_alias])
        if self.n_ctx:
            server_args.extend(["--ctx-size", self.n_ctx])
        if self.n_slots:
            server_args.extend(["--parallel", self.n_slots])
        if self.ctk:
            server_args.extend(["-ctk", self.ctk])
        if self.ctv:
            server_args.extend(["-ctv", self.ctv])
        if self.fa is not None:
            server_args.extend(["-fa", self.fa])
        if self.n_predict:
            server_args.extend(["--n-predict", self.n_predict])
        if self.slot_save_path:
            server_args.extend(["--slot-save-path", self.slot_save_path])
        if self.n_ga:
            server_args.extend(["--grp-attn-n", self.n_ga])
        if self.n_ga_w:
            server_args.extend(["--grp-attn-w", self.n_ga_w])
        if self.debug:
            server_args.append("--verbose")
        if self.lora_files:
            for lora_file in self.lora_files:
                server_args.extend(["--lora", lora_file])
        if self.enable_ctx_shift:
            server_args.append("--context-shift")
        if self.api_key:
            server_args.extend(["--api-key", self.api_key])
        if self.draft_max:
            server_args.extend(["--draft-max", self.draft_max])
        if self.draft_min:
            server_args.extend(["--draft-min", self.draft_min])
        if self.no_webui:
            server_args.append("--no-webui")
        if self.no_models_autoload:
            server_args.append("--no-models-autoload")
        if self.jinja:
            server_args.append("--jinja")
        else:
            server_args.append("--no-jinja")
        if self.reasoning_format is not None:
            server_args.extend(("--reasoning-format", self.reasoning_format))
        if self.reasoning_budget is not None:
            server_args.extend(("--reasoning-budget", self.reasoning_budget))
        if self.chat_template:
            server_args.extend(["--chat-template", self.chat_template])
        if self.chat_template_file:
            server_args.extend(["--chat-template-file", self.chat_template_file])
        if self.mmproj_url:
            server_args.extend(["--mmproj-url", self.mmproj_url])
        if self.media_path:
            server_args.extend(["--media-path", self.media_path])
        if self.sleep_idle_seconds is not None:
            server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds])
        args = [str(arg) for arg in [server_path, *server_args]]
        print(f"tests: starting server with: {' '.join(args)}")
        flags = 0
        if "nt" == os.name:
            flags |= subprocess.DETACHED_PROCESS
            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
            flags |= subprocess.CREATE_NO_WINDOW
        self.process = subprocess.Popen(
            args,
            creationflags=flags,
            stdout=sys.stdout,
            stderr=sys.stdout,
            env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
        )
        server_instances.add(self)
        print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")
        # wait for server to start
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
                response = self.make_request("GET", "/health", headers={
                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                })
                if response.status_code == 200:
                    self.ready = True
                    return  # server is ready
            except Exception:
                pass
            # Check if process died
            if self.process.poll() is not None:
                raise RuntimeError(f"Server process died with return code {self.process.returncode}")
            print("Waiting for server to start...")
            time.sleep(0.5)
        raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")

    def stop(self) -> None:
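        """ Kill the server process (if any) and deregister it from server_instances; no-op for external servers. """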
        if self.external_server:
            print("[external_server]: Not stopping external server")
            return
        if self in server_instances:
            server_instances.remove(self)
        if self.process:
            print(f"Stopping server with pid={self.process.pid}")
            self.process.kill()
            self.process = None

    def make_request(
        self,
        method: str,
        path: str,
        data: dict | Any | None = None,
        headers: dict | None = None,
        timeout: float | None = None,
    ) -> ServerResponse:
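        """
        Send a single (non-streaming) HTTP request and return a ServerResponse.
        Only GET, POST and OPTIONS are implemented; GET and POST bodies are parsed
        as JSON when possible, falling back to the raw response text.
        """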
        url = f"http://{self.server_host}:{self.server_port}{path}"
        parse_body = False
        if method == "GET":
            response = requests.get(url, headers=headers, timeout=timeout)
            parse_body = True
        elif method == "POST":
            response = requests.post(url, headers=headers, json=data, timeout=timeout)
            parse_body = True
        elif method == "OPTIONS":
            response = requests.options(url, headers=headers, timeout=timeout)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        result = ServerResponse()
        result.headers = dict(response.headers)
        result.status_code = response.status_code
        if parse_body:
            try:
                result.body = response.json()
            except JSONDecodeError:
                result.body = response.text
        else:
            result.body = None
        print("Response from server", json.dumps(result.body, indent=2))
        return result

    def make_stream_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        headers: dict | None = None,
    ) -> Iterator[dict]:
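        """
        Send a streaming POST request and yield each `data: {...}` SSE chunk as a dict,
        stopping at the `[DONE]` sentinel. Raises ServerError on a non-200 status.
        """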
        url = f"http://{self.server_host}:{self.server_port}{path}"
        if method == "POST":
            response = requests.post(url, headers=headers, json=data, stream=True)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        if response.status_code != 200:
            raise ServerError(response.status_code, response.json())
        for line_bytes in response.iter_lines():
            line = line_bytes.decode("utf-8")
            if '[DONE]' in line:
                break
            elif line.startswith('data: '):
                data = json.loads(line[6:])
                print("Partial response from server", json.dumps(data, indent=2))
                yield data

    def make_any_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        headers: dict | None = None,
        timeout: float | None = None,
    ) -> dict:
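        """
        Send a request that is streaming or not depending on data['stream'].
        Streamed chunks are validated and re-assembled into a single response dict
        shaped like a non-streaming chat completion, so callers can treat both
        modes uniformly.
        """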
        stream = data.get('stream', False) if data else False
        if stream:
            content: list[str] = []
            reasoning_content: list[str] = []
            tool_calls: list[dict] = []
            finish_reason: str | None = None
            content_parts = 0
            reasoning_content_parts = 0
            tool_call_parts = 0
            arguments_parts = 0

            for chunk in self.make_stream_request(method, path, data, headers):
                if chunk['choices']:
                    assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
                    choice = chunk['choices'][0]
                    if choice['delta'].get('content') is not None:
                        assert len(choice['delta']['content']) > 0, 'Expected non-empty content delta!'
                        content.append(choice['delta']['content'])
                        content_parts += 1
                    if choice['delta'].get('reasoning_content') is not None:
                        assert len(choice['delta']['reasoning_content']) > 0, 'Expected non-empty reasoning_content delta!'
                        reasoning_content.append(choice['delta']['reasoning_content'])
                        reasoning_content_parts += 1
                    if choice['delta'].get('finish_reason') is not None:
                        finish_reason = choice['delta']['finish_reason']
                    for tc in choice['delta'].get('tool_calls', []):
                        if 'function' not in tc:
                            raise ValueError(f"Expected function type, got {tc['type']}")
                        if tc['index'] >= len(tool_calls):
                            assert 'id' in tc
                            assert tc.get('type') == 'function'
                            assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
                                f"Expected function call with name, got {tc.get('function')}"
                            tool_calls.append(dict(
                                id="",
                                type="function",
                                function=dict(
                                    name="",
                                    arguments="",
                                )
                            ))
                        tool_call = tool_calls[tc['index']]
                        if tc.get('id') is not None:
                            tool_call['id'] = tc['id']
                        fct = tc['function']
                        assert 'id' not in fct, f"Function call should not have id: {fct}"
                        if fct.get('name') is not None:
                            tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
                        if fct.get('arguments') is not None:
                            tool_call['function']['arguments'] += fct['arguments']
                            arguments_parts += 1
                        tool_call_parts += 1
                else:
                    # When `include_usage` is True (the default), the last chunk of the stream,
                    # immediately preceding the `data: [DONE]` message, is expected to contain a
                    # `choices` field with an empty array and a `usage` field with the usage
                    # statistics (n.b., llama-server also returns `timings` in that chunk).
                    assert 'usage' in chunk, f"Expected usage in chunk: {chunk}"
                    assert 'timings' in chunk, f"Expected timings in chunk: {chunk}"
            print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
            result = dict(
                choices=[
                    dict(
                        index=0,
                        finish_reason=finish_reason,
                        message=dict(
                            role='assistant',
                            content=''.join(content) if content else None,
                            reasoning_content=''.join(reasoning_content) if reasoning_content else None,
                            tool_calls=tool_calls if tool_calls else None,
                        ),
                    )
                ],
            )
            print("Final response from server", json.dumps(result, indent=2))
            return result
        else:
            response = self.make_request(method, path, data, headers, timeout=timeout)
            assert response.status_code == 200, f"Server returned error: {response.status_code}"
            return response.body


server_instances: Set[ServerProcess] = set()


class ServerPreset:
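    """ Factory methods returning pre-configured ServerProcess instances used by the tests. """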
    @staticmethod
    def load_all() -> None:
        """ Load all server presets to ensure model files are cached. """
        servers: List[ServerProcess] = [
            method()
            for name, method in ServerPreset.__dict__.items()
            if callable(method) and name != "load_all"
        ]
        for server in servers:
            server.offline = False
            server.start()
            server.stop()

    @staticmethod
    def tinyllama2() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/test-model-stories260K"
        server.model_hf_file = None
        server.model_alias = "tinyllama-2"
        server.n_ctx = 512
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 64
        server.seed = 42
        return server

    @staticmethod
    def bert_bge_small() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 512
        server.n_batch = 128
        server.n_ubatch = 128
        server.n_slots = 2
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def bert_bge_small_with_fa() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 1024
        server.n_batch = 300
        server.n_ubatch = 300
        server.n_slots = 2
        server.fa = "on"
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def tinyllama_infill() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
        server.model_hf_file = None
        server.model_alias = "tinyllama-infill"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def stories15m_moe() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/stories15M_MOE"
        server.model_hf_file = "stories15M_MOE-F16.gguf"
        server.model_alias = "stories15m-moe"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def jina_reranker_tiny() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
        server.model_alias = "jina-reranker"
        server.n_ctx = 512
        server.n_batch = 512
        server.n_slots = 1
        server.seed = 42
        server.server_reranking = True
        return server

    @staticmethod
    def tinygemma3() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        # mmproj is already provided by HF registry API
        server.model_hf_file = None
        server.model_hf_repo = "ggml-org/tinygemma3-GGUF:Q8_0"
        server.model_alias = "tinygemma3"
        server.n_ctx = 1024
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 4
        server.seed = 42
        return server

    @staticmethod
    def router() -> ServerProcess:
        server = ServerProcess()
        server.offline = True  # will be downloaded by load_all()
        # router server has no models
        server.model_file = None
        server.model_alias = None
        server.model_hf_repo = None
        server.model_hf_file = None
        server.n_ctx = 1024
        server.n_batch = 16
        server.n_slots = 1
        server.n_predict = 16
        server.seed = 42
        return server


def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
    """
    Run multiple functions in parallel and return results in the same order as the calls.
    Equivalent to Promise.all in JS.

    Example usage:

        results = parallel_function_calls([
            (func1, (arg1, arg2)),
            (func2, (arg3, arg4)),
        ])
    """
    results = [None] * len(function_list)
    exceptions = []

    def worker(index, func, args):
        try:
            result = func(*args)
            results[index] = result
        except Exception as e:
            exceptions.append((index, str(e)))

    with ThreadPoolExecutor() as executor:
        futures = []
        for i, (func, args) in enumerate(function_list):
            future = executor.submit(worker, i, func, args)
            futures.append(future)

        # Wait for all futures to complete
        for future in as_completed(futures):
            pass

    # Check if there were any exceptions
    if exceptions:
        print("Exceptions occurred:")
        for index, error in exceptions:
            print(f"Function at index {index}: {error}")

    return results


def match_regex(regex: str, text: str) -> bool:
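    """ Return True if `regex` matches `text` (case-insensitive, multi-line, dot matches newline). """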
    return (
        re.compile(
            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
        ).search(text)
        is not None
    )


def download_file(url: str, output_file_path: str | None = None) -> str:
    """
    Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.

    output_file_path is the local path to save the downloaded file. If not provided, the file is saved
    under ./tmp/ using the last component of the URL as the file name.

    Returns the local path of the downloaded file.
    """
    file_name = url.split('/').pop()
    output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
    if not os.path.exists(output_file):
        print(f"Downloading {url} to {output_file}")
        wget.download(url, out=output_file)
        print(f"Done downloading to {output_file}")
    else:
        print(f"File already exists at {output_file}")
    return output_file


def is_slow_test_allowed():
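    """ Return True when slow tests are enabled via SLOW_TESTS=1 or SLOW_TESTS=ON. """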
    return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
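

# A minimal usage sketch (hypothetical, not part of the test suite): start a preset
# server, issue one completion request, then shut the server down.
#
#     server = ServerPreset.tinyllama2()
#     server.start()
#     res = server.make_request("POST", "/completion", data={
#         "prompt": "Once upon a time",  # hypothetical prompt
#         "n_predict": 8,
#     })
#     assert res.status_code == 200
#     server.stop()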