#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# type: ignore[reportUnusedImport]
import subprocess
import os
import re
import json
import sys
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
    Any,
    Callable,
    ContextManager,
    Iterable,
    Iterator,
    List,
    Literal,
    Tuple,
    Set,
)
from re import RegexFlag


class ServerResponse:
    headers: dict
    status_code: int
    body: dict | Any


class ServerProcess:
    # default options
    debug: bool = False
    server_port: int = 8080
    server_host: str = "127.0.0.1"
    model_hf_repo: str = "ggml-org/models"
    model_hf_file: str = "tinyllamas/stories260K.gguf"
    model_alias: str = "tinyllama-2"
    temperature: float = 0.8
    seed: int = 42

    # custom options
    model_url: str | None = None
    model_file: str | None = None
    model_draft: str | None = None
    n_threads: int | None = None
    n_gpu_layer: int | None = None
    n_batch: int | None = None
    n_ubatch: int | None = None
    n_ctx: int | None = None
    n_ga: int | None = None
    n_ga_w: int | None = None
    n_predict: int | None = None
    n_prompts: int | None = 0
    slot_save_path: str | None = None
    id_slot: int | None = None
    cache_prompt: bool | None = None
    n_slots: int | None = None
    server_continuous_batching: bool | None = False
    server_embeddings: bool | None = False
    server_reranking: bool | None = False
    server_metrics: bool | None = False
    server_slots: bool | None = False
    pooling: str | None = None
    draft: int | None = None
    api_key: str | None = None
    response_format: str | None = None
    lora_files: List[str] | None = None
    disable_ctx_shift: int | None = False
    draft_min: int | None = None
    draft_max: int | None = None
    no_webui: bool | None = None
    chat_template: str | None = None

    # session variables
    process: subprocess.Popen | None = None

    def __init__(self):
        if "N_GPU_LAYERS" in os.environ:
            self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
        if "DEBUG" in os.environ:
            self.debug = True
        if "PORT" in os.environ:
            self.server_port = int(os.environ["PORT"])

    def start(self, timeout_seconds: int = 10) -> None:
        if "LLAMA_SERVER_BIN_PATH" in os.environ:
            server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
        elif os.name == "nt":
            server_path = "../../../build/bin/Release/llama-server.exe"
        else:
            server_path = "../../../build/bin/llama-server"
        server_args = [
            "--host",
            self.server_host,
            "--port",
            self.server_port,
            "--temp",
            self.temperature,
            "--seed",
            self.seed,
        ]
        if self.model_file:
            server_args.extend(["--model", self.model_file])
        if self.model_url:
            server_args.extend(["--model-url", self.model_url])
        if self.model_draft:
            server_args.extend(["--model-draft", self.model_draft])
        if self.model_hf_repo:
            server_args.extend(["--hf-repo", self.model_hf_repo])
        if self.model_hf_file:
            server_args.extend(["--hf-file", self.model_hf_file])
        if self.n_batch:
            server_args.extend(["--batch-size", self.n_batch])
        if self.n_ubatch:
            server_args.extend(["--ubatch-size", self.n_ubatch])
        if self.n_threads:
            server_args.extend(["--threads", self.n_threads])
        if self.n_gpu_layer:
            server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
        if self.draft is not None:
            server_args.extend(["--draft", self.draft])
        if self.server_continuous_batching:
            server_args.append("--cont-batching")
        if self.server_embeddings:
            server_args.append("--embedding")
        if self.server_reranking:
            server_args.append("--reranking")
        if self.server_metrics:
            server_args.append("--metrics")
        if self.server_slots:
            server_args.append("--slots")
        if self.pooling:
            server_args.extend(["--pooling", self.pooling])
        if self.model_alias:
            server_args.extend(["--alias", self.model_alias])
        if self.n_ctx:
            server_args.extend(["--ctx-size", self.n_ctx])
        if self.n_slots:
            server_args.extend(["--parallel", self.n_slots])
        if self.n_predict:
            server_args.extend(["--n-predict", self.n_predict])
        if self.slot_save_path:
            server_args.extend(["--slot-save-path", self.slot_save_path])
        if self.n_ga:
            server_args.extend(["--grp-attn-n", self.n_ga])
        if self.n_ga_w:
            server_args.extend(["--grp-attn-w", self.n_ga_w])
        if self.debug:
            server_args.append("--verbose")
        if self.lora_files:
            for lora_file in self.lora_files:
                server_args.extend(["--lora", lora_file])
        if self.disable_ctx_shift:
            server_args.extend(["--no-context-shift"])
        if self.api_key:
            server_args.extend(["--api-key", self.api_key])
        if self.draft_max:
            server_args.extend(["--draft-max", self.draft_max])
        if self.draft_min:
            server_args.extend(["--draft-min", self.draft_min])
        if self.no_webui:
            server_args.append("--no-webui")
        if self.chat_template:
            server_args.extend(["--chat-template", self.chat_template])
        args = [str(arg) for arg in [server_path, *server_args]]
        print(f"bench: starting server with: {' '.join(args)}")

        flags = 0
        if "nt" == os.name:
            flags |= subprocess.DETACHED_PROCESS
            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
            flags |= subprocess.CREATE_NO_WINDOW

        self.process = subprocess.Popen(
            args,
            creationflags=flags,
            stdout=sys.stdout,
            stderr=sys.stdout,
            env={**os.environ, "LLAMA_CACHE": "tmp"},
        )
        server_instances.add(self)
        print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")

        # wait for server to start
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
                response = self.make_request("GET", "/health", headers={
                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                })
                if response.status_code == 200:
                    self.ready = True
                    return  # server is ready
            except Exception:
                pass
            print("Waiting for server to start...")
            time.sleep(0.5)
        raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")

    def stop(self) -> None:
        if self in server_instances:
            server_instances.remove(self)
        if self.process:
            print(f"Stopping server with pid={self.process.pid}")
            self.process.kill()
            self.process = None

    def make_request(
        self,
        method: str,
        path: str,
        data: dict | Any | None = None,
        headers: dict | None = None,
    ) -> ServerResponse:
        url = f"http://{self.server_host}:{self.server_port}{path}"
        parse_body = False
        if method == "GET":
            response = requests.get(url, headers=headers)
            parse_body = True
        elif method == "POST":
            response = requests.post(url, headers=headers, json=data)
            parse_body = True
        elif method == "OPTIONS":
            response = requests.options(url, headers=headers)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        result = ServerResponse()
        result.headers = dict(response.headers)
        result.status_code = response.status_code
        result.body = response.json() if parse_body else None
        print("Response from server", json.dumps(result.body, indent=2))
        return result

    def make_stream_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        headers: dict | None = None,
    ) -> Iterator[dict]:
        url = f"http://{self.server_host}:{self.server_port}{path}"
        if method == "POST":
            response = requests.post(url, headers=headers, json=data, stream=True)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        for line_bytes in response.iter_lines():
            line = line_bytes.decode("utf-8")
            if '[DONE]' in line:
                break
            elif line.startswith('data: '):
                data = json.loads(line[6:])
                print("Partial response from server", json.dumps(data, indent=2))
                yield data
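
# Illustrative sketch (not part of the test suite): consuming a streamed response
# via ServerProcess.make_stream_request. The "/completion" endpoint and the
# "prompt"/"n_predict"/"stream" fields are assumptions about the llama-server HTTP
# API; only the helper itself is defined above.
#
#   server = ServerProcess()
#   server.start()
#   try:
#       for chunk in server.make_stream_request("POST", "/completion", data={
#           "prompt": "Once upon a time",
#           "n_predict": 16,
#           "stream": True,
#       }):
#           print(chunk.get("content", ""), end="")
#   finally:
#       server.stop()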


server_instances: Set[ServerProcess] = set()


class ServerPreset:
    @staticmethod
    def tinyllama2() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "tinyllamas/stories260K.gguf"
        server.model_alias = "tinyllama-2"
        server.n_ctx = 256
        server.n_batch = 32
        server.n_slots = 2
        server.n_predict = 64
        server.seed = 42
        return server

    @staticmethod
    def bert_bge_small() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
        server.model_alias = "bert-bge-small"
        server.n_ctx = 512
        server.n_batch = 128
        server.n_ubatch = 128
        server.n_slots = 2
        server.seed = 42
        server.server_embeddings = True
        return server

    @staticmethod
    def tinyllama_infill() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
        server.model_alias = "tinyllama-infill"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def stories15m_moe() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/stories15M_MOE"
        server.model_hf_file = "stories15M_MOE-F16.gguf"
        server.model_alias = "stories15m-moe"
        server.n_ctx = 2048
        server.n_batch = 1024
        server.n_slots = 1
        server.n_predict = 64
        server.temperature = 0.0
        server.seed = 42
        return server

    @staticmethod
    def jina_reranker_tiny() -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
        server.model_alias = "jina-reranker"
        server.n_ctx = 512
        server.n_batch = 512
        server.n_slots = 1
        server.seed = 42
        server.server_reranking = True
        return server


def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
    """
    Run multiple functions in parallel and return results in the same order as calls.
    Equivalent to Promise.all in JS.

    Example usage:

    results = parallel_function_calls([
        (func1, (arg1, arg2)),
        (func2, (arg3, arg4)),
    ])
    """
    results = [None] * len(function_list)
    exceptions = []

    def worker(index, func, args):
        try:
            result = func(*args)
            results[index] = result
        except Exception as e:
            exceptions.append((index, str(e)))

    with ThreadPoolExecutor() as executor:
        futures = []
        for i, (func, args) in enumerate(function_list):
            future = executor.submit(worker, i, func, args)
            futures.append(future)

        # Wait for all futures to complete
        for future in as_completed(futures):
            pass

    # Check if there were any exceptions
    if exceptions:
        print("Exceptions occurred:")
        for index, error in exceptions:
            print(f"Function at index {index}: {error}")

    return results


def match_regex(regex: str, text: str) -> bool:
    return (
        re.compile(
            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
        ).search(text)
        is not None
    )


def is_slow_test_allowed():
    return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
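

# Illustrative sketch (not part of the test suite): a minimal smoke run using the
# helpers above. ServerPreset.tinyllama2(), start(), make_request(), and stop() are
# defined in this file; "/health" is the endpoint start() already polls, while the
# "/completion" request body ("prompt", "n_predict") is an assumption about the
# llama-server HTTP API and may need adjusting for your build.
if __name__ == "__main__":
    demo_server = ServerPreset.tinyllama2()
    demo_server.start()
    try:
        health = demo_server.make_request("GET", "/health")
        print("health status:", health.status_code)
        completion = demo_server.make_request("POST", "/completion", data={
            "prompt": "Once upon a time",
            "n_predict": 16,
        })
        print("completion status:", completion.status_code)
    finally:
        demo_server.stop()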