#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# type: ignore[reportUnusedImport]

import subprocess
import os
import re
import json
import sys
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
    Any,
    Callable,
    ContextManager,
    Iterable,
    Iterator,
    List,
    Literal,
    Tuple,
    Set,
)
from re import RegexFlag
import wget

DEFAULT_HTTP_TIMEOUT = 12

if "LLAMA_SANITIZE" in os.environ or "GITHUB_ACTION" in os.environ:
    DEFAULT_HTTP_TIMEOUT = 30


class ServerResponse:
    headers: dict
    status_code: int
    body: dict | Any


class ServerProcess:
    # default options
    debug: bool = False
    server_port: int = 8080
    server_host: str = "127.0.0.1"
    model_hf_repo: str = "ggml-org/models"
    model_hf_file: str | None = "tinyllamas/stories260K.gguf"
    model_alias: str = "tinyllama-2"
    temperature: float = 0.8
    seed: int = 42

    # custom options
    model_url: str | None = None
    model_file: str | None = None
    model_draft: str | None = None
    n_threads: int | None = None
    n_gpu_layer: int | None = None
    n_batch: int | None = None
    n_ubatch: int | None = None
    n_ctx: int | None = None
    n_ga: int | None = None
    n_ga_w: int | None = None
    n_predict: int | None = None
    n_prompts: int | None = 0
    slot_save_path: str | None = None
    id_slot: int | None = None
    cache_prompt: bool | None = None
    n_slots: int | None = None
    server_continuous_batching: bool | None = False
    server_embeddings: bool | None = False
    server_reranking: bool | None = False
    server_metrics: bool | None = False
    server_slots: bool | None = False
    pooling: str | None = None
    draft: int | None = None
    api_key: str | None = None
    lora_files: List[str] | None = None
    disable_ctx_shift: bool | None = False
    draft_min: int | None = None
    draft_max: int | None = None
    no_webui: bool | None = None
    jinja: bool | None = None
    reasoning_format: Literal['deepseek', 'none'] | None = None
    chat_template: str | None = None
    chat_template_file: str | None = None

    # session variables
    process: subprocess.Popen | None = None

    def __init__(self):
        if "N_GPU_LAYERS" in os.environ:
            self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
        if "DEBUG" in os.environ:
            self.debug = True
        if "PORT" in os.environ:
            self.server_port = int(os.environ["PORT"])

    def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
        if "LLAMA_SERVER_BIN_PATH" in os.environ:
            server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
        elif os.name == "nt":
            server_path = "../../../build/bin/Release/llama-server.exe"
        else:
            server_path = "../../../build/bin/llama-server"
        server_args = [
            "--host",
            self.server_host,
            "--port",
            self.server_port,
            "--temp",
            self.temperature,
            "--seed",
            self.seed,
        ]
        if self.model_file:
            server_args.extend(["--model", self.model_file])
        if self.model_url:
            server_args.extend(["--model-url", self.model_url])
        if self.model_draft:
            server_args.extend(["--model-draft", self.model_draft])
        if self.model_hf_repo:
            server_args.extend(["--hf-repo", self.model_hf_repo])
        if self.model_hf_file:
            server_args.extend(["--hf-file", self.model_hf_file])
        if self.n_batch:
            server_args.extend(["--batch-size", self.n_batch])
        if self.n_ubatch:
            server_args.extend(["--ubatch-size", self.n_ubatch])
        if self.n_threads:
            server_args.extend(["--threads", self.n_threads])
        if self.n_gpu_layer:
            server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
        if self.draft is not None:
            server_args.extend(["--draft", self.draft])
        if self.server_continuous_batching:
            server_args.append("--cont-batching")
        if self.server_embeddings:
            server_args.append("--embedding")
        if self.server_reranking:
            server_args.append("--reranking")
        if self.server_metrics:
            server_args.append("--metrics")
        if self.server_slots:
            server_args.append("--slots")
        if self.pooling:
            server_args.extend(["--pooling", self.pooling])
        if self.model_alias:
            server_args.extend(["--alias", self.model_alias])
        if self.n_ctx:
            server_args.extend(["--ctx-size", self.n_ctx])
        if self.n_slots:
            server_args.extend(["--parallel", self.n_slots])
        if self.n_predict:
            server_args.extend(["--n-predict", self.n_predict])
        if self.slot_save_path:
            server_args.extend(["--slot-save-path", self.slot_save_path])
        if self.n_ga:
            server_args.extend(["--grp-attn-n", self.n_ga])
        if self.n_ga_w:
            server_args.extend(["--grp-attn-w", self.n_ga_w])
        if self.debug:
            server_args.append("--verbose")
        if self.lora_files:
            for lora_file in self.lora_files:
                server_args.extend(["--lora", lora_file])
        if self.disable_ctx_shift:
            server_args.extend(["--no-context-shift"])
        if self.api_key:
            server_args.extend(["--api-key", self.api_key])
        if self.draft_max:
            server_args.extend(["--draft-max", self.draft_max])
        if self.draft_min:
            server_args.extend(["--draft-min", self.draft_min])
        if self.no_webui:
            server_args.append("--no-webui")
        if self.jinja:
            server_args.append("--jinja")
        if self.reasoning_format is not None:
            server_args.extend(("--reasoning-format", self.reasoning_format))
        if self.chat_template:
            server_args.extend(["--chat-template", self.chat_template])
        if self.chat_template_file:
            server_args.extend(["--chat-template-file", self.chat_template_file])

        args = [str(arg) for arg in [server_path, *server_args]]
        print(f"bench: starting server with: {' '.join(args)}")

        flags = 0
        if "nt" == os.name:
            flags |= subprocess.DETACHED_PROCESS
            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
            flags |= subprocess.CREATE_NO_WINDOW

        self.process = subprocess.Popen(
            args,
            creationflags=flags,
            stdout=sys.stdout,
            stderr=sys.stdout,
            env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
        )
        server_instances.add(self)

        print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")

        # wait for server to start
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
                response = self.make_request("GET", "/health", headers={
                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                })
                if response.status_code == 200:
                    self.ready = True
                    return  # server is ready
            except Exception:
                pass
            print("Waiting for server to start...")
            time.sleep(0.5)
        raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")

    def stop(self) -> None:
        if self in server_instances:
            server_instances.remove(self)
        if self.process:
            print(f"Stopping server with pid={self.process.pid}")
            self.process.kill()
            self.process = None

    def make_request(
        self,
        method: str,
        path: str,
        data: dict | Any | None = None,
        headers: dict | None = None,
        timeout: float | None = None,
    ) -> ServerResponse:
        url = f"http://{self.server_host}:{self.server_port}{path}"
        parse_body = False
        if method == "GET":
            response = requests.get(url, headers=headers, timeout=timeout)
            parse_body = True
        elif method == "POST":
            response = requests.post(url, headers=headers, json=data, timeout=timeout)
            parse_body = True
        elif method == "OPTIONS":
            response = requests.options(url, headers=headers, timeout=timeout)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        result = ServerResponse()
        result.headers = dict(response.headers)
        result.status_code = response.status_code
        result.body = response.json() if parse_body else None
        print("Response from server", json.dumps(result.body, indent=2))
        return result

    def make_stream_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        headers: dict | None = None,
    ) -> Iterator[dict]:
        url = f"http://{self.server_host}:{self.server_port}{path}"
        if method == "POST":
            response = requests.post(url, headers=headers, json=data, stream=True)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        for line_bytes in response.iter_lines():
            line = line_bytes.decode("utf-8")
            if '[DONE]' in line:
                break
            elif line.startswith('data: '):
                data = json.loads(line[6:])
                print("Partial response from server", json.dumps(data, indent=2))
                yield data


server_instances: Set[ServerProcess] = set()


class ServerPreset:
    @staticmethod
    def tinyllama2() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "tinyllamas/stories260K.gguf"
        server.model_alias = "tinyllama-2"
        server.n_ctx = 256
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 64
        server.seed = 42
        return server

    @staticmethod
    def bert_bge_small() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 512
        server.n_batch = 128
        server.n_ubatch = 128
        server.n_slots = 2
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def tinyllama_infill() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
        server.model_alias = "tinyllama-infill"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def stories15m_moe() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/stories15M_MOE"
        server.model_hf_file = "stories15M_MOE-F16.gguf"
        server.model_alias = "stories15m-moe"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def jina_reranker_tiny() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
        server.model_alias = "jina-reranker"
        server.n_ctx = 512
        server.n_batch = 512
        server.n_slots = 1
        server.seed = 42
        server.server_reranking = True
        return server
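

# Usage sketch (illustrative only, not executed by this module): a typical test
# builds a preset, starts the server, issues a request, and stops it. The
# endpoint and payload below are assumptions for illustration, not fixtures
# defined here.
#
#     server = ServerPreset.tinyllama2()
#     server.start()
#     res = server.make_request("POST", "/completion", data={
#         "prompt": "Once upon a time",
#         "n_predict": 16,
#     })
#     assert res.status_code == 200
#     server.stop()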


def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
    """
    Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.

    Example usage:

    results = parallel_function_calls([
        (func1, (arg1, arg2)),
        (func2, (arg3, arg4)),
    ])
    """
    results = [None] * len(function_list)
    exceptions = []

    def worker(index, func, args):
        try:
            result = func(*args)
            results[index] = result
        except Exception as e:
            exceptions.append((index, str(e)))

    with ThreadPoolExecutor() as executor:
        futures = []
        for i, (func, args) in enumerate(function_list):
            future = executor.submit(worker, i, func, args)
            futures.append(future)

        # Wait for all futures to complete
        for future in as_completed(futures):
            pass

    # Check if there were any exceptions
    if exceptions:
        print("Exceptions occurred:")
        for index, error in exceptions:
            print(f"Function at index {index}: {error}")

    return results


def match_regex(regex: str, text: str) -> bool:
    return (
        re.compile(
            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
        ).search(text)
        is not None
    )


def download_file(url: str, output_file_path: str | None = None) -> str:
    """
    Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.

    output_file_path is the local path to save the downloaded file. If not provided, the file is saved under ./tmp with the file name taken from the URL.

    Returns the local path of the downloaded file.
    """
    file_name = url.split('/').pop()
    output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
    if not os.path.exists(output_file):
        print(f"Downloading {url} to {output_file}")
        wget.download(url, out=output_file)
        print(f"Done downloading to {output_file}")
    else:
        print(f"File already exists at {output_file}")
    return output_file
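

# Usage sketch (the URL is illustrative only): the first call downloads the
# file to ./tmp/<name>; later calls with the same URL reuse the existing copy.
#
#     model_path = download_file("https://example.com/stories260K.gguf")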


def is_slow_test_allowed():
    return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"