#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# type: ignore[reportUnusedImport]

import subprocess
import os
import re
import json
import sys
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
    Any,
    Callable,
    ContextManager,
    Iterable,
    Iterator,
    List,
    Literal,
    Tuple,
    Set,
)
from re import RegexFlag

import wget
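
# Sanitizer builds run noticeably slower, so allow a longer default HTTP
# timeout when LLAMA_SANITIZE is set.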
DEFAULT_HTTP_TIMEOUT = 10 if "LLAMA_SANITIZE" not in os.environ else 30


class ServerResponse:
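    """Plain container for an HTTP response from llama-server.

    `body` holds the parsed JSON payload for GET and POST requests, and is
    None for methods whose body is not parsed (e.g. OPTIONS).
    """
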
    headers: dict
    status_code: int
    body: dict | Any


class ServerProcess:
    # default options
    debug: bool = False
    server_port: int = 8080
    server_host: str = "127.0.0.1"
    model_hf_repo: str = "ggml-org/models"
    model_hf_file: str = "tinyllamas/stories260K.gguf"
    model_alias: str = "tinyllama-2"
    temperature: float = 0.8
    seed: int = 42

    # custom options
    model_url: str | None = None
    model_file: str | None = None
    model_draft: str | None = None
    n_threads: int | None = None
    n_gpu_layer: int | None = None
    n_batch: int | None = None
    n_ubatch: int | None = None
    n_ctx: int | None = None
    n_ga: int | None = None
    n_ga_w: int | None = None
    n_predict: int | None = None
    n_prompts: int | None = 0
    slot_save_path: str | None = None
    id_slot: int | None = None
    cache_prompt: bool | None = None
    n_slots: int | None = None
    server_continuous_batching: bool | None = False
    server_embeddings: bool | None = False
    server_reranking: bool | None = False
    server_metrics: bool | None = False
    server_slots: bool | None = False
    pooling: str | None = None
    draft: int | None = None
    api_key: str | None = None
    response_format: str | None = None
    lora_files: List[str] | None = None
    disable_ctx_shift: int | None = False
    draft_min: int | None = None
    draft_max: int | None = None
    no_webui: bool | None = None
    chat_template: str | None = None

    # session variables
    process: subprocess.Popen | None = None

    def __init__(self):
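        """Pick up optional overrides from the environment (N_GPU_LAYERS, DEBUG, PORT)."""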
        if "N_GPU_LAYERS" in os.environ:
            self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
        if "DEBUG" in os.environ:
            self.debug = True
        if "PORT" in os.environ:
            self.server_port = int(os.environ["PORT"])

    def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
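        """Launch llama-server with the configured options and block until
        /health reports ready, raising TimeoutError after timeout_seconds."""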
        if "LLAMA_SERVER_BIN_PATH" in os.environ:
            server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
        elif os.name == "nt":
            server_path = "../../../build/bin/Release/llama-server.exe"
        else:
            server_path = "../../../build/bin/llama-server"
        server_args = [
            "--host",
            self.server_host,
            "--port",
            self.server_port,
            "--temp",
            self.temperature,
            "--seed",
            self.seed,
        ]
        if self.model_file:
            server_args.extend(["--model", self.model_file])
        if self.model_url:
            server_args.extend(["--model-url", self.model_url])
        if self.model_draft:
            server_args.extend(["--model-draft", self.model_draft])
        if self.model_hf_repo:
            server_args.extend(["--hf-repo", self.model_hf_repo])
        if self.model_hf_file:
            server_args.extend(["--hf-file", self.model_hf_file])
        if self.n_batch:
            server_args.extend(["--batch-size", self.n_batch])
        if self.n_ubatch:
            server_args.extend(["--ubatch-size", self.n_ubatch])
        if self.n_threads:
            server_args.extend(["--threads", self.n_threads])
        if self.n_gpu_layer:
            server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
        if self.draft is not None:
            server_args.extend(["--draft", self.draft])
        if self.server_continuous_batching:
            server_args.append("--cont-batching")
        if self.server_embeddings:
            server_args.append("--embedding")
        if self.server_reranking:
            server_args.append("--reranking")
        if self.server_metrics:
            server_args.append("--metrics")
        if self.server_slots:
            server_args.append("--slots")
        if self.pooling:
            server_args.extend(["--pooling", self.pooling])
        if self.model_alias:
            server_args.extend(["--alias", self.model_alias])
        if self.n_ctx:
            server_args.extend(["--ctx-size", self.n_ctx])
        if self.n_slots:
            server_args.extend(["--parallel", self.n_slots])
        if self.n_predict:
            server_args.extend(["--n-predict", self.n_predict])
        if self.slot_save_path:
            server_args.extend(["--slot-save-path", self.slot_save_path])
        if self.n_ga:
            server_args.extend(["--grp-attn-n", self.n_ga])
        if self.n_ga_w:
            server_args.extend(["--grp-attn-w", self.n_ga_w])
        if self.debug:
            server_args.append("--verbose")
        if self.lora_files:
            for lora_file in self.lora_files:
                server_args.extend(["--lora", lora_file])
        if self.disable_ctx_shift:
            server_args.append("--no-context-shift")
        if self.api_key:
            server_args.extend(["--api-key", self.api_key])
        if self.draft_max:
            server_args.extend(["--draft-max", self.draft_max])
        if self.draft_min:
            server_args.extend(["--draft-min", self.draft_min])
        if self.no_webui:
            server_args.append("--no-webui")
        if self.chat_template:
            server_args.extend(["--chat-template", self.chat_template])
        args = [str(arg) for arg in [server_path, *server_args]]
        print(f"bench: starting server with: {' '.join(args)}")
        flags = 0
        if os.name == "nt":
            flags |= subprocess.DETACHED_PROCESS
            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
            flags |= subprocess.CREATE_NO_WINDOW
        self.process = subprocess.Popen(
            args,
            creationflags=flags,
            stdout=sys.stdout,
            stderr=sys.stdout,
            env={**os.environ, "LLAMA_CACHE": "tmp"},
        )
        server_instances.add(self)
        print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")
        # wait for server to start
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
                response = self.make_request("GET", "/health", headers={
                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                })
                if response.status_code == 200:
                    self.ready = True
                    return  # server is ready
            except Exception:
                pass
            print("Waiting for server to start...")
            time.sleep(0.5)
        raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")

    def stop(self) -> None:
        if self in server_instances:
            server_instances.remove(self)
        if self.process:
            print(f"Stopping server with pid={self.process.pid}")
            self.process.kill()
            self.process = None

    def make_request(
        self,
        method: str,
        path: str,
        data: dict | Any | None = None,
        headers: dict | None = None,
        timeout: float | None = None,
    ) -> ServerResponse:
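        """Send a blocking GET/POST/OPTIONS request and return a ServerResponse;
        the JSON body is parsed for GET and POST only."""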
        url = f"http://{self.server_host}:{self.server_port}{path}"
        parse_body = False
        if method == "GET":
            response = requests.get(url, headers=headers, timeout=timeout)
            parse_body = True
        elif method == "POST":
            response = requests.post(url, headers=headers, json=data, timeout=timeout)
            parse_body = True
        elif method == "OPTIONS":
            response = requests.options(url, headers=headers, timeout=timeout)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        result = ServerResponse()
        result.headers = dict(response.headers)
        result.status_code = response.status_code
        result.body = response.json() if parse_body else None
        print("Response from server", json.dumps(result.body, indent=2))
        return result

    def make_stream_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        headers: dict | None = None,
    ) -> Iterator[dict]:
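        """POST with stream=True and yield each SSE 'data:' chunk as a dict
        until the '[DONE]' sentinel arrives."""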
        url = f"http://{self.server_host}:{self.server_port}{path}"
        if method == "POST":
            response = requests.post(url, headers=headers, json=data, stream=True)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        for line_bytes in response.iter_lines():
            line = line_bytes.decode("utf-8")
            if '[DONE]' in line:
                break
            elif line.startswith('data: '):
                data = json.loads(line[6:])
                print("Partial response from server", json.dumps(data, indent=2))
                yield data
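

# Example (a sketch, assuming a running `server` and that the endpoint streams
# chunks carrying a "content" field, as llama-server's /completion does):
#
#   for chunk in server.make_stream_request("POST", "/completion",
#                                           data={"prompt": "Hi", "stream": True}):
#       print(chunk.get("content", ""), end="")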

server_instances: Set[ServerProcess] = set()


class ServerPreset:
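    """Factory helpers that return preconfigured ServerProcess instances
    for the small test models used by the suite."""
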
    @staticmethod
    def tinyllama2() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "tinyllamas/stories260K.gguf"
        server.model_alias = "tinyllama-2"
        server.n_ctx = 256
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 64
        server.seed = 42
        return server

    @staticmethod
    def bert_bge_small() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 512
        server.n_batch = 128
        server.n_ubatch = 128
        server.n_slots = 2
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def tinyllama_infill() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
        server.model_alias = "tinyllama-infill"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def stories15m_moe() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/stories15M_MOE"
        server.model_hf_file = "stories15M_MOE-F16.gguf"
        server.model_alias = "stories15m-moe"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def jina_reranker_tiny() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
        server.model_alias = "jina-reranker"
        server.n_ctx = 512
        server.n_batch = 512
        server.n_slots = 1
        server.seed = 42
        server.server_reranking = True
        return server
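

# Typical lifecycle (a sketch; assumes the llama-server binary exists at the
# default relative build path, or that LLAMA_SERVER_BIN_PATH points at it):
#
#   server = ServerPreset.tinyllama2()
#   server.start()
#   res = server.make_request("POST", "/completion", data={"prompt": "Once upon a time"})
#   assert res.status_code == 200
#   server.stop()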


def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
    """
    Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.

    Example usage:

    results = parallel_function_calls([
        (func1, (arg1, arg2)),
        (func2, (arg3, arg4)),
    ])
    """
    results = [None] * len(function_list)
    exceptions = []

    def worker(index, func, args):
        try:
            result = func(*args)
            results[index] = result
        except Exception as e:
            exceptions.append((index, str(e)))

    with ThreadPoolExecutor() as executor:
        futures = []
        for i, (func, args) in enumerate(function_list):
            future = executor.submit(worker, i, func, args)
            futures.append(future)

        # Wait for all futures to complete
        for future in as_completed(futures):
            pass

    # Check if there were any exceptions
    if exceptions:
        print("Exceptions occurred:")
        for index, error in exceptions:
            print(f"Function at index {index}: {error}")

    return results
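

# Example (a sketch, assuming a running `server`): issue two completion
# requests concurrently and collect both responses in order.
#
#   res_a, res_b = parallel_function_calls([
#       (server.make_request, ("POST", "/completion", {"prompt": "A"})),
#       (server.make_request, ("POST", "/completion", {"prompt": "B"})),
#   ])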


def match_regex(regex: str, text: str) -> bool:
    return (
        re.compile(
            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
        ).search(text)
        is not None
    )
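
# For example, match_regex(r"^world", "Hello\nworld") is True: IGNORECASE,
# MULTILINE and DOTALL are always applied.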


def download_file(url: str, output_file_path: str | None = None) -> str:
    """
    Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.

    output_file_path is the local path to save the downloaded file. If not provided, the file will be saved under ./tmp, named after the last path component of the URL.

    Returns the local path of the downloaded file.
    """
    file_name = url.split('/')[-1]
    output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
    if not os.path.exists(output_file):
        # make sure the target directory exists before wget writes into it
        os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
        print(f"Downloading {url} to {output_file}")
        wget.download(url, out=output_file)
        print(f"Done downloading to {output_file}")
    else:
        print(f"File already exists at {output_file}")
    return output_file


def is_slow_test_allowed() -> bool:
    return os.environ.get("SLOW_TESTS") in ("1", "ON")