#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import json
import os
import re
import socket
import subprocess
import sys
import threading
import time

import requests
from collections.abc import Sequence
from contextlib import closing
from re import RegexFlag
from typing import Any, Literal, cast

import aiohttp
import numpy as np
import openai
from openai.types.chat import ChatCompletionChunk
from behave import step  # pyright: ignore[reportAttributeAccessIssue]
from behave.api.async_step import async_run_until_complete
from prometheus_client import parser

# pyright: reportRedeclaration=false

DEFAULT_TIMEOUT_SECONDS = aiohttp.ClientTimeout(total=600)
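# The @step decorators below register behave step implementations for the server
# feature files; several step functions are intentionally redeclared under the
# same name, hence the pyright directive above.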
@step("a server listening on {server_fqdn}:{server_port}")
def step_server_config(context, server_fqdn: str, server_port: str):
    context.server_fqdn = server_fqdn
    context.server_port = int(server_port)
    context.n_threads = None
    context.n_gpu_layer = None
    if 'PORT' in os.environ:
        context.server_port = int(os.environ['PORT'])
        print(f"$PORT set, overriding server port to {context.server_port}")
    if 'FQDN' in os.environ:
        context.server_fqdn = os.environ['FQDN']
        print(f"$FQDN set, overriding server fqdn to {context.server_fqdn}")
    if 'N_GPU_LAYERS' in os.environ:
        context.n_gpu_layer = int(os.environ['N_GPU_LAYERS'])
        print(f"$N_GPU_LAYERS set, overriding n_gpu_layer to {context.n_gpu_layer}")

    context.base_url = f'http://{context.server_fqdn}:{context.server_port}'

    context.model_alias = None
    context.model_file = None
    context.model_hf_repo = None
    context.model_hf_file = None
    context.model_url = None
    context.n_batch = None
    context.n_ubatch = None
    context.n_ctx = None
    context.n_ga = None
    context.n_ga_w = None
    context.n_predict = None
    context.n_prompts = 0
    context.n_server_predict = None
    context.slot_save_path = None
    context.id_slot = None
    context.cache_prompt = None
    context.n_slots = None
    context.prompt_prefix = None
    context.prompt_suffix = None
    context.server_api_key = None
    context.server_continuous_batching = False
    context.server_embeddings = False
    context.server_reranking = False
    context.server_metrics = False
    context.server_process = None
    context.seed = None
    context.draft = None
    context.server_seed = None
    context.user_api_key = None
    context.response_format = None
    context.temperature = None
    context.lora_file = None
    context.disable_ctx_shift = False

    # infill
    context.infill_input_extra = None
    context.infill_input_suffix = ''
    context.infill_input_prefix = ''

    context.tasks_result = []
    context.concurrent_tasks = []
    context.prompts = []

    context.reranking_query = None
    context.reranking_documents = []
    context.reranking_results = None
@step('a model file {hf_file} from HF repo {hf_repo}')
def step_download_hf_model(context, hf_file: str, hf_repo: str):
    context.model_hf_repo = hf_repo
    context.model_hf_file = hf_file
    context.model_file = os.path.basename(hf_file)


@step('a lora adapter file from {lora_file_url}')
def step_download_lora_file(context, lora_file_url: str):
    file_name = lora_file_url.split('/').pop()
    context.lora_file = f'../../../{file_name}'
    with open(context.lora_file, 'wb') as f:
        f.write(requests.get(lora_file_url).content)


@step('a model file {model_file}')
def step_model_file(context, model_file: str):
    context.model_file = model_file


@step('a model url {model_url}')
def step_model_url(context, model_url: str):
    context.model_url = model_url


@step('a model alias {model_alias}')
def step_model_alias(context, model_alias: str):
    context.model_alias = model_alias


@step('{seed:d} as server seed')
def step_seed(context, seed: int):
    context.server_seed = seed


@step('{ngl:d} GPU offloaded layers')
def step_n_gpu_layer(context, ngl: int):
    if 'N_GPU_LAYERS' in os.environ:
        new_ngl = int(os.environ['N_GPU_LAYERS'])
        if context.debug:
            print(f"-ngl upgraded from {ngl} to {new_ngl}")
        ngl = new_ngl
    context.n_gpu_layer = ngl
@step('{n_threads:d} threads')
def step_n_threads(context, n_threads: int):
    # store under the same attribute name that step_server_config initializes
    # and start_server_background reads
    context.n_threads = n_threads
@step('{draft:d} as draft')
def step_draft(context, draft: int):
    context.draft = draft


@step('{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx: int):
    context.n_ctx = n_ctx


@step('{n_slots:d} slots')
def step_n_slots(context, n_slots: int):
    context.n_slots = n_slots


@step('{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict: int):
    context.n_server_predict = n_predict if n_predict > 0 else None


@step('{slot_save_path} as slot save path')
def step_slot_save_path(context, slot_save_path: str):
    context.slot_save_path = slot_save_path


@step('using slot id {id_slot:d}')
def step_id_slot(context, id_slot: int):
    context.id_slot = id_slot


@step('prompt caching is enabled')
def step_enable_prompt_cache(context):
    context.cache_prompt = True


@step('continuous batching')
def step_server_continuous_batching(context):
    context.server_continuous_batching = True


@step('enable embeddings endpoint')
def step_server_embeddings(context):
    context.server_embeddings = True


@step('enable reranking endpoint')
def step_server_reranking(context):
    context.server_reranking = True


@step('prometheus compatible metrics exposed')
def step_server_metrics(context):
    context.server_metrics = True


@step('disable context shifting')
def step_server_disable_ctx_shift(context):
    context.disable_ctx_shift = True
@step("the server is starting")
def step_start_server(context):
    start_server_background(context)
    attempts = 0
    max_attempts = 20
    if 'GITHUB_ACTIONS' in os.environ:
        max_attempts *= 2

    addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM)
    family, typ, proto, _, sockaddr = addrs[0]

    while True:
        with closing(socket.socket(family, typ, proto)) as sock:
            result = sock.connect_ex(sockaddr)
            if result == 0:
                print("\x1b[33;46mserver started!\x1b[0m")
                return

            attempts += 1
            if attempts > max_attempts:
                assert False, "server not started"
            print(f"waiting for server to start, connect error code = {result}...")
            time.sleep(0.1)
async def wait_for_server_status_with_timeout(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str, timeout: int):
    match expecting_status:
        case 'healthy':
            await wait_for_slots_status(context, context.base_url, 200,
                                        timeout=timeout)
        case 'ready' | 'idle':
            await wait_for_slots_status(context, context.base_url, 200,
                                        timeout=timeout,
                                        params={'fail_on_no_slot': 1},
                                        slots_idle=context.n_slots,
                                        slots_processing=0)
        case 'busy':
            await wait_for_slots_status(context, context.base_url, 503,
                                        params={'fail_on_no_slot': 1},
                                        slots_idle=0,
                                        slots_processing=context.n_slots)
        case _:
            assert False, "unknown status"


@step("the server is {expecting_status} with timeout {timeout:d} seconds")
@async_run_until_complete
async def step_wait_for_server_status_with_timeout(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str, timeout: int):
    await wait_for_server_status_with_timeout(context, expecting_status, timeout)


@step("the server is {expecting_status}")
@async_run_until_complete
async def step_wait_for_server_status(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
    await wait_for_server_status_with_timeout(context, expecting_status, 30)


@step('all slots are {expected_slot_status_string}')
@async_run_until_complete
async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
    match expected_slot_status_string:
        case 'idle':
            expected_slot_status = False
        case 'busy':
            expected_slot_status = True
        case _:
            assert False, "unknown status"

    expected_slots = [{'id': slot_id, 'is_processing': expected_slot_status}
                      for slot_id in range(context.n_slots)]
    await request_slots_status(context, expected_slots)
@step('a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error: Literal['raised'] | str):
    expect_api_error = api_error == 'raised' or api_error != 'no'
    seeds = await completions_seed(context, num_seeds=1)
    completion = await request_completion(context.prompts.pop(),
                                          seeds[0] if seeds is not None else seeds,
                                          context.base_url,
                                          debug=context.debug,
                                          n_predict=context.n_predict,
                                          cache_prompt=context.cache_prompt,
                                          id_slot=context.id_slot,
                                          expect_api_error=expect_api_error,
                                          user_api_key=context.user_api_key,
                                          temperature=context.temperature)
    context.tasks_result.append(completion)
    if context.debug:
        print(f"Completion response: {completion}")
    if api_error == 'raised':
        assert completion == 401, f"completion must be a 401 status code: {completion}"
    elif api_error.isdigit():
        api_error_code = int(api_error)
        assert completion == api_error_code, f"completion must be a {api_error_code} status code: {completion}"
@step('an infill request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error: Literal['raised'] | str):
    if api_error != 'no':
        raise ValueError(f'api_error={api_error} is not yet implemented')
    payload = {
        "prompt": context.prompts[0],
        "input_suffix": context.infill_input_suffix,
        "input_prefix": context.infill_input_prefix,
        "n_predict": context.n_predict,
        "seed": context.seed,
        "temperature": context.temperature,
    }
    if context.infill_input_extra is not None:
        payload['input_extra'] = context.infill_input_extra
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{context.base_url}/infill',
                                json=payload) as response:
            assert response.status == 200
            context.tasks_result = [await response.json()]
@step('{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
    context.completion = context.tasks_result.pop()
    assert_n_tokens_predicted(context.completion, predicted_n, re_content)


@step('{predicted_n:d} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n):
    context.completion = context.tasks_result.pop()
    assert_n_tokens_predicted(context.completion, predicted_n)


@step('all predictions are equal')
@async_run_until_complete
async def step_predictions_equal(context):
    n_completions = await gather_tasks_results(context)
    assert n_completions >= 2, "need at least 2 completions"
    assert_all_predictions_equal(context.tasks_result)
    context.tasks_result = []


@step('all predictions are different')
@async_run_until_complete
async def step_predictions_different(context):
    n_completions = await gather_tasks_results(context)
    assert n_completions >= 2, "need at least 2 completions"
    assert_all_predictions_different(context.tasks_result)
    context.tasks_result = []


@step('all token probabilities are equal')
@async_run_until_complete
async def step_token_probabilities_equal(context):
    n_completions = await gather_tasks_results(context)
    assert n_completions >= 2, "need at least 2 completions"
    assert_all_token_probabilities_equal(context.tasks_result)
    context.tasks_result = []


@step('the completion is truncated')
def step_assert_completion_truncated(context):
    step_assert_completion_truncated(context, '')


@step('the completion is {truncated} truncated')
def step_assert_completion_truncated(context, truncated):
    truncated = truncated != "not"
    assert context.completion['truncated'] == truncated, f'{context.completion}'


@step('{n_prompt:d} prompt tokens are processed')
def step_impl(context, n_prompt):
    assert n_prompt < 0 or n_prompt == context.completion['timings']['prompt_n'], f"n_prompt={context.completion['timings']['prompt_n']}"
@step('a user prompt {user_prompt}')
def step_user_prompt(context, user_prompt):
    context.prompts.append(user_prompt)
    context.n_prompts = len(context.prompts)


@step('a system prompt {system_prompt}')
def step_system_prompt(context, system_prompt):
    context.system_prompt = system_prompt


@step('a model {model}')
def step_model(context, model):
    context.model = model


@step('{max_tokens:d} max tokens to predict')
def step_max_tokens(context, max_tokens):
    context.n_predict = max_tokens


@step('a response format {response_format}')
def step_response_format(context, response_format):
    context.response_format = json.loads(response_format)


@step('{temperature:f} temperature')
def step_temperature(context, temperature):
    context.temperature = temperature


@step('streaming is {enable_streaming}')
def step_streaming(context, enable_streaming):
    context.enable_streaming = enable_streaming == 'enabled'


@step('a user api key {user_api_key}')
def step_user_api_key(context, user_api_key):
    context.user_api_key = user_api_key


@step('no user api key')
def step_no_user_api_key(context):
    context.user_api_key = None


@step('a user api key ')
def step_no_user_api_key_space(context):
    context.user_api_key = None


@step('a server api key {server_api_key}')
def step_server_api_key(context, server_api_key):
    context.server_api_key = server_api_key


@step('{n_junk:d} as number of junk')
def step_n_junk(context, n_junk):
    context.n_junk = n_junk


@step('{n_batch:d} as batch size')
def step_n_batch(context, n_batch):
    context.n_batch = n_batch


@step('{n_ubatch:d} as ubatch size')
def step_n_ubatch(context, n_ubatch):
    context.n_ubatch = n_ubatch


@step('{seed:d} as seed')
def step_seed(context, seed):
    if context.seed is None:
        context.seed = [seed]
    else:
        context.seed.append(seed)


@step('BOS token is {bos:d}')
def step_bos_token(context, bos):
    context.bos = bos


@step('a prefix prompt')
def step_prompt_prefix(context):
    context.prompt_prefix = context_text(context)


@step('a junk suffix prompt')
def step_prompt_junk_suffix(context):
    context.prompt_junk_suffix = context_text(context)


@step('a suffix prompt')
def step_prompt_suffix(context):
    context.prompt_suffix = context_text(context)
@step('{n_ga:d} group attention factor'
      ' to extend context size through self-extend')
def step_impl(context, n_ga):
    context.n_ga = n_ga


@step('{n_ga_w:d} group attention width to extend context size through self-extend')
def step_impl(context, n_ga_w):
    context.n_ga_w = n_ga_w


@step('a passkey prompt template')
def step_prompt_passkey(context):
    context.prompt_passkey = context_text(context)


@step('a rerank query')
def step_set_rerank_query(context):
    context.reranking_query = context_text(context)
    context.reranking_documents = []


@step('a rerank document')
def step_set_rerank_document(context):
    context.reranking_documents.append(context_text(context))


@step('{n_prompts:d} fixed prompts')
def step_fixed_prompts(context, n_prompts):
    context.prompts.extend([str(0) * (context.n_batch if context.n_batch is not None else 512) for i in range(n_prompts)])
    context.n_prompts = n_prompts


@step('a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
def step_prompt_passkey(context, passkey, i_pos):
    prompt = ""
    for i in range(context.n_junk):
        if i % context.n_junk == i_pos:
            prompt += context.prompt_passkey  # the passkey is already substituted
        prompt += context.prompt_junk_suffix
    if context.debug:
        passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```")
    context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
    context.n_prompts = len(context.prompts)
@step('an OAI compatible chat completions request with {api_error} api error')
@async_run_until_complete
async def step_oai_chat_completions(context, api_error):
    if context.debug:
        print("Submitting OAI compatible completions request...")
    expect_api_error = api_error == 'raised'
    seeds = await completions_seed(context, num_seeds=1)
    completion = await oai_chat_completions(context.prompts.pop(),
                                            seeds[0] if seeds is not None else seeds,
                                            context.system_prompt,
                                            context.base_url,
                                            '/v1/chat',
                                            False,
                                            model=context.model if hasattr(context, 'model') else None,
                                            n_predict=context.n_predict
                                            if hasattr(context, 'n_predict') else None,
                                            enable_streaming=context.enable_streaming
                                            if hasattr(context, 'enable_streaming') else None,
                                            response_format=context.response_format
                                            if hasattr(context, 'response_format') else None,
                                            user_api_key=context.user_api_key
                                            if hasattr(context, 'user_api_key') else None,
                                            expect_api_error=expect_api_error)
    context.tasks_result.append(completion)
    if context.debug:
        print(f"Completion response: {completion}")
    if expect_api_error:
        assert completion == 401, f"completion must be a 401 status code: {completion}"
    if context.debug:
        print(f"Completion response: {completion}")
@step('a prompt')
def step_a_prompt(context):
    context.prompts.append(context_text(context))
    context.n_prompts = len(context.prompts)


@step('a prompt {prompt}')
def step_a_prompt_prompt(context, prompt):
    context.prompts.append(prompt)
    context.n_prompts = len(context.prompts)


# TODO: allow this to be repeated
@step('an infill input extra {filename} {text}')
def step_infill_input_extra(context, filename, text):
    if filename == 'none':
        context.infill_input_extra = None
    else:
        context.infill_input_extra = [{'filename': filename, 'text': text}]


@step('an infill input suffix {text}')
def step_infill_input_suffix(context, text):
    context.infill_input_suffix = text


@step('an infill input prefix {text}')
def step_infill_input_prefix(context, text):
    context.infill_input_prefix = text


@step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
def step_many_prompts(context, num_prompts, prompt, seed):
    if context.seed is None:
        context.seed = []
    for _ in range(num_prompts):
        context.seed.append(seed)
        context.prompts.append(prompt)
    context.n_prompts = len(context.prompts)
@step('concurrent completion requests')
@async_run_until_complete()
async def step_concurrent_completion_requests(context):
    await concurrent_requests(
        context,
        request_completion,
        # prompt is inserted automatically
        context.base_url,
        debug=context.debug,
        prompt_prefix=context.prompt_prefix,
        prompt_suffix=context.prompt_suffix,
        n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
        user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None,
        temperature=context.temperature,
    )


@step('concurrent OAI completions requests')
@async_run_until_complete
async def step_oai_chat_completions(context):
    await concurrent_requests(context, oai_chat_completions,
                              # user_prompt is inserted automatically
                              context.system_prompt,
                              context.base_url,
                              '/v1/chat/completions',
                              True,  # async_client
                              model=context.model
                              if hasattr(context, 'model') else None,
                              n_predict=context.n_predict
                              if hasattr(context, 'n_predict') else None,
                              enable_streaming=context.enable_streaming
                              if hasattr(context, 'enable_streaming') else None,
                              response_format=context.response_format
                              if hasattr(context, 'response_format') else None,
                              user_api_key=context.user_api_key
                              if hasattr(context, 'user_api_key') else None)


@step('concurrent OAI completions requests no v1')
@async_run_until_complete
async def step_oai_chat_completions(context):
    await concurrent_requests(context, oai_chat_completions,
                              # user_prompt is inserted automatically
                              context.system_prompt,
                              context.base_url,
                              '/chat/completions',
                              True,  # async_client
                              model=context.model
                              if hasattr(context, 'model') else None,
                              n_predict=context.n_predict
                              if hasattr(context, 'n_predict') else None,
                              enable_streaming=context.enable_streaming
                              if hasattr(context, 'enable_streaming') else None,
                              response_format=context.response_format
                              if hasattr(context, 'response_format') else None,
                              user_api_key=context.user_api_key
                              if hasattr(context, 'user_api_key') else None)
@step('all prompts are predicted')
@async_run_until_complete
async def step_all_prompts_are_predicted(context):
    await all_prompts_are_predicted(context)


@step('all prompts are predicted with {n_expected_predicted:d} tokens')
@async_run_until_complete
async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted):
    await all_prompts_are_predicted(context, n_expected_predicted)


async def all_prompts_are_predicted(context, expected_predicted_n=None):
    n_completions = await gather_tasks_results(context)
    assert n_completions > 0
    for i in range(n_completions):
        assert_n_tokens_predicted(context.tasks_result.pop(), expected_predicted_n=expected_predicted_n)
    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
@step('embeddings are computed for')
@async_run_until_complete
async def step_compute_embedding(context):
    context.n_prompts = 1
    context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url)


@step('reranking request')
@async_run_until_complete
async def step_compute_reranking(context):
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{context.base_url}/reranking',
                                json={
                                    "query": context.reranking_query,
                                    "documents": context.reranking_documents,
                                }) as response:
            if response.status == 200:
                response_json = await response.json()
                context.reranking_results = response_json['results']
            else:
                context.reranking_results = response.status


@step('all embeddings are the same')
@async_run_until_complete
async def step_all_embeddings_are_the_same(context):
    n_embedding_requests = await gather_tasks_results(context)
    assert n_embedding_requests > 0
    embeddings = []
    for i in range(n_embedding_requests):
        embedding = context.tasks_result.pop().pop()
        embeddings.append(embedding)
        assert_embeddings(embedding)
    n = len(embeddings)
    for i in range(n - 1):
        for j in range(i + 1, n):
            embedding1 = np.array(embeddings[i])
            embedding2 = np.array(embeddings[j])
            if context.debug:
                print(f"embedding1: {embedding1[-8:]}")
                print(f"embedding2: {embedding2[-8:]}")
            similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
            msg = f"Similarity between {i} and {j}: {similarity:.10f}"
            if context.debug:
                print(f"{msg}")
            assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg


@step('embeddings are generated')
def step_assert_embeddings(context):
    assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n"
                                                          f"context.n_prompts={context.n_prompts}\n"
                                                          f"context.embeddings={context.embeddings}")
    for embedding in context.embeddings:
        assert_embeddings(embedding)
@step('embeddings request with {api_error_code:d} api error')
def step_assert_embeddings(context, api_error_code: int):
    assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"


@step('an OAI compatible embeddings computation request for')
@async_run_until_complete
async def step_oai_compute_embeddings(context):
    context.n_prompts = 1
    context.embeddings = await request_oai_embeddings(context_text(context), None,
                                                      base_url=context.base_url,
                                                      user_api_key=context.user_api_key,
                                                      model=context.model)


@step('an OAI compatible embeddings computation request for multiple inputs')
@async_run_until_complete
async def step_oai_compute_embeddings_multiple_inputs(context):
    context.embeddings = await request_oai_embeddings(context.prompts, None,
                                                      base_url=context.base_url,
                                                      user_api_key=context.user_api_key,
                                                      model=context.model)
    context.prompts.clear()


@step('concurrent embedding requests')
@async_run_until_complete()
async def step_concurrent_embedding_requests(context):
    await concurrent_requests(context,
                              request_embedding,
                              # prompt is inserted automatically
                              base_url=context.base_url)


@step('concurrent OAI embedding requests')
@async_run_until_complete()
async def step_concurrent_oai_embedding_requests(context):
    await concurrent_requests(context,
                              request_oai_embeddings,
                              # prompt is inserted automatically
                              base_url=context.base_url,
                              async_client=True,
                              model=context.model)


@step('all embeddings are generated')
@async_run_until_complete()
async def all_embeddings_are_generated(context):
    n_embedding_requests = await gather_tasks_results(context)
    assert n_embedding_requests == context.n_prompts
    for i in range(n_embedding_requests):
        assert_embeddings(context.tasks_result.pop().pop())
@step('reranking results are returned')
def reranking_results_are_returned(context):
    assert len(context.reranking_results) == len(context.reranking_documents)


@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}')
def reranking_results_are_returned(context, idx_high: int, idx_low: int):
    max_score, max_idx = 0, 0
    min_score, min_idx = 0, 0
    for res in context.reranking_results:
        if max_score < res['relevance_score']:
            max_score = res['relevance_score']
            max_idx = res['index']
        if min_score > res['relevance_score']:
            min_score = res['relevance_score']
            min_idx = res['index']
    print(context.reranking_results)
    assert max_idx == idx_high
    assert min_idx == idx_low
@step('adding special tokens')
def step_tokenize_set_add_special(context):
    context.tokenize_add_special = True


@step("tokenizing with pieces")
@async_run_until_complete
async def step_tokenize_with_pieces(context):
    context.tokenized_text = context_text(context)
    async with aiohttp.ClientSession() as session:
        tokenize_args = {"content": context.tokenized_text, "with_pieces": True}
        if getattr(context, "tokenize_add_special", None) is not None:
            tokenize_args["add_special"] = context.tokenize_add_special

        async with session.post(
            f"{context.base_url}/tokenize", json=tokenize_args
        ) as response:
            assert response.status == 200
            tokenize_json = await response.json()
            context.tokens_with_pieces = tokenize_json["tokens"]


@step("tokens are given with pieces")
@async_run_until_complete
async def step_tokenize_with_pieces(context):
    # Verify that the response contains both token IDs and pieces
    assert all(
        "id" in token and "piece" in token for token in context.tokens_with_pieces
    )
@step('tokenizing')
@async_run_until_complete
async def step_tokenize(context):
    context.tokenized_text = context_text(context)
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        tokenize_args = {
            "content": context.tokenized_text,
        }
        if getattr(context, 'tokenize_add_special', None) is not None:
            tokenize_args['add_special'] = context.tokenize_add_special
        async with session.post(f'{context.base_url}/tokenize',
                                json=tokenize_args) as response:
            assert response.status == 200
            tokenize_json = await response.json()
            context.tokens = tokenize_json['tokens']


@step('tokens can be detokenized')
@async_run_until_complete
async def step_detokenize(context):
    assert len(context.tokens) > 0
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{context.base_url}/detokenize',
                                json={
                                    "tokens": context.tokens,
                                }) as response:
            assert response.status == 200
            detokenize_json = await response.json()
            # SPM tokenizer adds a whitespace prefix: https://github.com/google/sentencepiece/issues/15
            assert context.tokenized_text == detokenize_json['content'].strip()


@step('tokens begin with BOS')
def step_strings_for_tokenization(context):
    assert context.tokens[0] == context.bos


@step('tokens do not begin with BOS')
def step_strings_for_tokenization(context):
    assert context.tokens[0] != context.bos


@step('first token is removed')
def step_strings_for_tokenization(context):
    context.tokens = context.tokens[1:]
@step('an OPTIONS request is sent from {origin}')
@async_run_until_complete
async def step_options_request(context, origin):
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        headers = {'Authorization': f'Bearer {context.user_api_key}', 'Origin': origin}
        async with session.options(f'{context.base_url}/v1/chat/completions',
                                   headers=headers) as response:
            assert response.status == 200
            context.options_response = response


@step('CORS header {cors_header} is set to {cors_header_value}')
def step_check_options_header_value(context, cors_header, cors_header_value):
    assert context.options_response.headers[cors_header] == cors_header_value


@step('prometheus metrics are exposed')
@async_run_until_complete
async def step_prometheus_metrics_exported(context):
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with await session.get(f'{context.base_url}/metrics') as metrics_response:
            assert metrics_response.status == 200
            assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4"
            metrics_raw = await metrics_response.text()
            metric_exported = False
            if context.debug:
                print(f"/metrics answer:\n{metrics_raw}")
            context.metrics = {}
            for metric in parser.text_string_to_metric_families(metrics_raw):
                match metric.name:
                    case "llamacpp:kv_cache_usage_ratio":
                        assert len(metric.samples) > 0
                        metric_exported = True
                context.metrics[metric.name] = metric
            assert int(metrics_response.headers["Process-Start-Time-Unix"]) > 0, "no header process start time"
            assert metric_exported, "No metrics exported"


@step('metric {metric_name} is {metric_value:d}')
def step_assert_metric_value(context, metric_name, metric_value):
    if metric_name not in context.metrics:
        assert False, f"no metric {metric_name} in {context.metrics.keys()}"
    assert context.metrics[metric_name].samples[0].value == metric_value, f"metric: {context.metrics[metric_name]}"
@step('available models')
def step_available_models(context):
    # openai client always expects an api_key
    openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
    openai.base_url = f'{context.base_url}/v1/'
    context.models = openai.models.list().data


@step('{n_model:d} models are supported')
def step_supported_models(context, n_model):
    if context.debug:
        print("server models available:", context.models)
    assert len(context.models) == n_model


@step('model {i_model:d} is {param} {preposition} {param_value}')
def step_supported_models(context, i_model: int, param: Literal['identified', 'trained'] | str, preposition: str, param_value: str):
    assert i_model < len(context.models)
    model = context.models[i_model]

    param_value = param_value.split(' ', 1)[0]
    match param:
        case 'identified':
            value = model.id
        case 'trained':
            value = str(model.meta["n_ctx_train"])
        case _:
            assert False, f"param {param} not supported"
    assert param_value == value, f"model param {param} {value} != {param_value}"
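# Fire one asynchronous request per queued prompt via f_completion; the created
# tasks are collected later by gather_tasks_results().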
async def concurrent_requests(context, f_completion, *args, **kwargs):
    context.n_prompts = len(context.prompts)
    if context.debug:
        print(f"starting {context.n_prompts} concurrent completion requests...")
    assert context.n_prompts > 0
    seeds = await completions_seed(context)
    assert seeds is not None
    for prompt_no in range(context.n_prompts):
        shifted_args = [context.prompts.pop(), seeds[prompt_no], *args]
        context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
    await asyncio.sleep(0.01)
@step('the slot {slot_id:d} is saved with filename "{filename}"')
@async_run_until_complete
async def step_save_slot(context, slot_id, filename):
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{context.base_url}/slots/{slot_id}?action=save',
                                json={"filename": filename},
                                headers={"Content-Type": "application/json"}) as response:
            context.response = response


@step('the slot {slot_id:d} is restored with filename "{filename}"')
@async_run_until_complete
async def step_restore_slot(context, slot_id, filename):
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore',
                                json={"filename": filename},
                                headers={"Content-Type": "application/json"}) as response:
            context.response = response


@step('the slot {slot_id:d} is erased')
@async_run_until_complete
async def step_erase_slot(context, slot_id):
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase',
                                headers={"Content-Type": "application/json"}) as response:
            context.response = response


@step('switch {on_or_off} lora adapter {lora_id:d}')
@async_run_until_complete
async def toggle_lora_adapter(context, on_or_off: str, lora_id: int):
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{context.base_url}/lora-adapters',
                                json=[{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}],
                                headers={"Content-Type": "application/json"}) as response:
            context.response = response
            print([{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}])


@step('the server responds with status code {status_code:d}')
def step_server_responds_with_status_code(context, status_code):
    assert context.response.status == status_code
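# POST to the native /completion endpoint and return the parsed JSON response,
# or the raw HTTP status code when an API error is expected.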
async def request_completion(prompt,
                             seed,
                             base_url,
                             debug=False,
                             prompt_prefix=None,
                             prompt_suffix=None,
                             n_predict=None,
                             cache_prompt=False,
                             id_slot=None,
                             expect_api_error=None,
                             user_api_key=None,
                             temperature=None) -> int | dict[str, Any]:
    if debug:
        print(f"Sending completion request: {prompt}")
    origin = "my.super.domain"
    headers = {
        'Origin': origin
    }
    if user_api_key is not None:
        if debug:
            print(f"Set user_api_key: {user_api_key}")
        headers['Authorization'] = f'Bearer {user_api_key}'

    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{base_url}/completion',
                                json={
                                    "input_prefix": prompt_prefix,
                                    "prompt": prompt,
                                    "input_suffix": prompt_suffix,
                                    "n_predict": n_predict if n_predict is not None else -1,
                                    "cache_prompt": cache_prompt,
                                    "id_slot": id_slot,
                                    "seed": seed if seed is not None else 42,
                                    "temperature": temperature if temperature is not None else 0.8,
                                    "n_probs": 2,
                                },
                                headers=headers) as response:
            if expect_api_error is None or not expect_api_error:
                assert response.status == 200
                assert response.headers['Access-Control-Allow-Origin'] == origin
                return await response.json()
            else:
                return response.status
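# Send an OAI-compatible chat completions request, either through a raw aiohttp
# session (async_client=True, optionally streaming over SSE) or through the
# openai client, and normalize the result to the llama.cpp response shape.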
async def oai_chat_completions(user_prompt,
                               seed,
                               system_prompt,
                               base_url: str,
                               base_path: str,
                               async_client,
                               debug=False,
                               temperature=None,
                               model=None,
                               n_predict=None,
                               enable_streaming=None,
                               response_format=None,
                               user_api_key=None,
                               expect_api_error=None) -> int | dict[str, Any]:
    if debug:
        print(f"Sending OAI Chat completions request: {user_prompt}")
    # openai client always expects an api key
    user_api_key = user_api_key if user_api_key is not None else 'nope'
    seed = seed if seed is not None else 42
    enable_streaming = enable_streaming if enable_streaming is not None else False
    payload = {
        "messages": [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": user_prompt,
            }
        ],
        "model": model,
        "max_tokens": n_predict,
        "stream": enable_streaming,
        "temperature": temperature if temperature is not None else 0.0,
        "seed": seed,
    }
    if response_format is not None:
        payload['response_format'] = response_format
    completion_response = {
        'content': '',
        'timings': {
            'predicted_n': 0,
            'prompt_n': 0
        }
    }
    if async_client:
        origin = 'llama.cpp'
        headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
        async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
            async with session.post(f'{base_url}{base_path}',
                                    json=payload,
                                    headers=headers) as response:
                if enable_streaming:
                    assert response.status == 200
                    assert response.headers['Access-Control-Allow-Origin'] == origin
                    assert response.headers['Content-Type'] == "text/event-stream"
                    event_received = True
                    while event_received:
                        event_received = False
                        async for line_in_bytes in response.content:
                            line = line_in_bytes.decode('utf-8')
                            line = line.rstrip('\n').rstrip('\r')
                            if line == '':
                                continue
                            event_data = line.split(': ', 1)
                            assert event_data[0] == 'data', f'Bad event code received: ```{event_data}```'
                            chunk_raw = event_data[1]
                            if chunk_raw == '[DONE]':
                                break

                            chunk = json.loads(chunk_raw)
                            assert len(chunk['choices']) == 1, f"no choices provided, line ```{line}```"
                            delta = chunk['choices'][0]['delta']
                            if 'content' in delta:
                                completion_response['content'] += delta['content']
                                completion_response['timings']['predicted_n'] += 1
                else:
                    if expect_api_error is None or not expect_api_error:
                        assert response.status == 200
                        assert response.headers['Access-Control-Allow-Origin'] == origin
                        assert response.headers['Content-Type'] == "application/json; charset=utf-8"
                        chat_completion_raw = await response.json()
                        completion_response = {
                            'content': chat_completion_raw['choices'][0]['message'],
                            'timings': {
                                'predicted_n': chat_completion_raw['usage']['completion_tokens'],
                                'prompt_n': chat_completion_raw['usage']['prompt_tokens']
                            }
                        }
                    else:
                        return response.status
    else:
        try:
            openai.api_key = user_api_key
            openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'
            assert model is not None
            chat_completion = openai.chat.completions.create(
                messages=payload['messages'],
                model=model,
                max_tokens=n_predict,
                stream=enable_streaming,
                response_format=payload.get('response_format') or openai.NOT_GIVEN,
                seed=seed,
                temperature=payload['temperature']
            )
        except openai.AuthenticationError as e:
            if expect_api_error is not None and expect_api_error:
                return 401
            else:
                assert False, f'error raised: {e}'

        if enable_streaming:
            chat_completion = cast(openai.Stream[ChatCompletionChunk], chat_completion)
            for chunk in chat_completion:
                assert len(chunk.choices) == 1
                delta = chunk.choices[0].delta
                if delta.content is not None:
                    completion_response['content'] += delta.content
                    completion_response['timings']['predicted_n'] += 1
                completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
        else:
            assert len(chat_completion.choices) == 1
            assert chat_completion.usage is not None
            completion_response = {
                'content': chat_completion.choices[0].message.content,
                'timings': {
                    'predicted_n': chat_completion.usage.completion_tokens,
                    'prompt_n': chat_completion.usage.prompt_tokens
                },
                'truncated': chat_completion.choices[0].finish_reason != 'stop'
            }
    if debug:
        print("OAI response formatted to llama.cpp:", completion_response)
    return completion_response
async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{base_url}/embedding',
                                json={
                                    "content": content,
                                }) as response:
            if response.status == 200:
                response_json = await response.json()
                return [response_json['embedding']]
            else:
                return response.status
async def request_oai_embeddings(input, seed,
                                 base_url=None, user_api_key=None,
                                 model=None, async_client=False) -> list[list[float]]:
    # openai client always expects an api_key
    user_api_key = user_api_key if user_api_key is not None else 'nope'
    if async_client:
        origin = 'llama.cpp'
        headers = {}
        if user_api_key is not None:
            headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
        async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
            async with session.post(f'{base_url}/v1/embeddings',
                                    json={
                                        "input": input,
                                        "model": model,
                                    },
                                    headers=headers) as response:
                assert response.status == 200, f"received status code not expected: {response.status}"
                assert response.headers['Access-Control-Allow-Origin'] == origin
                assert response.headers['Content-Type'] == "application/json; charset=utf-8"
                response_json = await response.json()
                assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
                assert response_json['object'] == 'list'
                if isinstance(input, Sequence):
                    embeddings = []
                    for an_oai_embeddings in response_json['data']:
                        embeddings.append(an_oai_embeddings['embedding'])
                else:
                    embeddings = [response_json['data']['embedding']]
                return embeddings
    else:
        openai.api_key = user_api_key
        openai.base_url = f'{base_url}/v1/'
        assert model is not None
        oai_embeddings = openai.embeddings.create(
            model=model,
            input=input,
        )
        return [e.embedding for e in oai_embeddings.data]
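# Check that a completion produced content, optionally matching re_content and
# an exact predicted token count; regex matches are highlighted when DEBUG=ON.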
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
    content = completion_response['content']
    n_predicted = completion_response['timings']['predicted_n']
    assert len(content) > 0, "no token predicted"
    if re_content is not None:
        p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
        matches = p.finditer(content)
        last_match = 0
        highlighted = ''
        for match in matches:
            start, end = match.span()
            highlighted += content[last_match: start]
            highlighted += '\x1b[33m'
            highlighted += content[start: end]
            highlighted += '\x1b[0m'
            last_match = end
        highlighted += content[last_match:]
        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
            print(f"Checking completion response: {highlighted}")
        assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
    if expected_predicted_n and expected_predicted_n > 0:
        assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                     f' {n_predicted} <> {expected_predicted_n}')
def assert_all_predictions_equal(completion_responses):
    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
        for i, response_i in enumerate(completion_responses):
            content_i = response_i['content']
            print(f"content {i}: {content_i}")
    for i, response_i in enumerate(completion_responses):
        content_i = response_i['content']
        for j, response_j in enumerate(completion_responses):
            if i == j:
                continue
            content_j = response_j['content']
            assert content_i == content_j, "contents not equal"


def assert_all_predictions_different(completion_responses):
    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
        for i, response_i in enumerate(completion_responses):
            content_i = response_i['content']
            print(f"content {i}: {content_i}")
    for i, response_i in enumerate(completion_responses):
        content_i = response_i['content']
        for j, response_j in enumerate(completion_responses):
            if i == j:
                continue
            content_j = response_j['content']
            assert content_i != content_j, "contents not different"


def assert_all_token_probabilities_equal(completion_responses):
    n_predict = len(completion_responses[0]['completion_probabilities'])
    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
        for pos in range(n_predict):
            for i, response_i in enumerate(completion_responses):
                probs_i = response_i['completion_probabilities'][pos]['probs']
                print(f"pos {pos}, probs {i}: {probs_i}")
    for pos in range(n_predict):
        for i, response_i in enumerate(completion_responses):
            probs_i = response_i['completion_probabilities'][pos]['probs']
            for j, response_j in enumerate(completion_responses):
                if i == j:
                    continue
                probs_j = response_j['completion_probabilities'][pos]['probs']
                assert probs_i == probs_j, "contents not equal"
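# Await every pending concurrent task, append the results to context.tasks_result
# and return the total number of collected results.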
async def gather_tasks_results(context):
    n_tasks = len(context.concurrent_tasks)
    if context.debug:
        print(f"Waiting for all {n_tasks} tasks results...")
    for task_no in range(n_tasks):
        context.tasks_result.append(await context.concurrent_tasks.pop())
    n_completions = len(context.tasks_result)
    return n_completions
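# Poll /slots until it answers with the expected HTTP status code and, when given,
# the expected number of idle/processing slots, or fail once the timeout is hit.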
async def wait_for_slots_status(context,
                                base_url,
                                expected_http_status_code,
                                timeout=3,
                                params=None,
                                slots_idle=None,
                                slots_processing=None):
    if context.debug:
        print(f"Starting health check, expecting http status code={expected_http_status_code}")
    interval = 0.5
    counter = 0
    if 'GITHUB_ACTIONS' in os.environ:
        timeout *= 2

    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        while True:
            headers = {'Authorization': f'Bearer {context.server_api_key}'}
            async with await session.get(f'{base_url}/slots', params=params, headers=headers) as slots_response:
                status_code = slots_response.status
                slots = await slots_response.json()
                if context.debug:
                    print(f"slots responses {slots}\n")
                if status_code == 503 and status_code == expected_http_status_code:
                    return
                if status_code == 200 and status_code == expected_http_status_code:
                    n_slots_idle = sum(1 if not slot["is_processing"] else 0 for slot in slots)
                    n_slots_processing = sum(1 if slot["is_processing"] else 0 for slot in slots)
                    if ((slots_idle is None or slots_idle == n_slots_idle)
                            and (slots_processing is None or slots_processing == n_slots_processing)):
                        return
            await asyncio.sleep(interval)
            counter += interval
            if counter >= timeout:
                # Sometimes health requests are triggered after completions are predicted
                if expected_http_status_code == 503:
                    if len(context.tasks_result) == 0:
                        print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
                              " busy health check missed, probably too fast inference\x1b[0m\n")
                        n_completions = await gather_tasks_results(context)
                        if n_completions > 0:
                            return
                assert False, f'slots check timeout exceeded {counter}s>={timeout}'
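# An embedding is considered valid if it is a non-empty sequence of floats with
# at least one non-zero value.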
def assert_embeddings(embeddings):
    assert len(embeddings) > 0
    embeddings_computed = False
    for emb in embeddings:
        if not isinstance(emb, float):
            assert False, f"Bad embeddings: {embeddings}"
        if emb != 0:
            embeddings_computed = True
    assert embeddings_computed, f"Embeddings: {embeddings}"


async def request_slots_status(context, expected_slots):
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with await session.get(f'{context.base_url}/slots') as slots_response:
            assert slots_response.status == 200
            slots = await slots_response.json()
            assert_slots_status(slots, expected_slots)


def assert_slots_status(slots, expected_slots):
    assert len(slots) == len(expected_slots)
    for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)):
        for key in expected:
            assert expected[key] == slot[key], (f"invalid slot {slot_id}"
                                                f" expected[{key}] != slot[{key}]"
                                                f" = {expected[key]} != {slot[key]}")
async def completions_seed(context, num_seeds=None):
    if hasattr(context, "seed") and context.seed is not None:
        assert len(context.seed) == context.n_prompts
        if num_seeds is None:
            num_seeds = context.n_prompts
        assert num_seeds <= context.n_prompts
        seeds = context.seed[:num_seeds]
        context.seed = context.seed[num_seeds:] if num_seeds < context.n_prompts else None
        return seeds

    if hasattr(context, "server_seed") and context.server_seed is not None:
        if num_seeds is None:
            return [context.server_seed] * context.n_prompts
        else:
            return [context.server_seed] * num_seeds
    return None


def context_text(context):
    return context.text.replace('\r', '')
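# Build the llama-server command line from the scenario context and launch the
# binary in the background, mirroring its stdout/stderr to the test output.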
def start_server_background(context):
    if os.name == 'nt':
        context.server_path = '../../../build/bin/Release/llama-server.exe'
    else:
        context.server_path = '../../../build/bin/llama-server'
    if 'LLAMA_SERVER_BIN_PATH' in os.environ:
        context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
    server_listen_addr = context.server_fqdn
    server_args = [
        '--slots',  # required to get slot status via the /slots endpoint
        '--host', server_listen_addr,
        '--port', context.server_port,
    ]
    if context.model_file:
        server_args.extend(['--model', context.model_file])
    if context.model_url:
        server_args.extend(['--model-url', context.model_url])
    if context.model_hf_repo:
        server_args.extend(['--hf-repo', context.model_hf_repo])
    if context.model_hf_file:
        server_args.extend(['--hf-file', context.model_hf_file])
    if context.n_batch:
        server_args.extend(['--batch-size', context.n_batch])
    if context.n_ubatch:
        server_args.extend(['--ubatch-size', context.n_ubatch])
    if context.n_threads:
        server_args.extend(['--threads', context.n_threads])
    if context.n_gpu_layer:
        server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
    if context.draft is not None:
        server_args.extend(['--draft', context.draft])
    if context.server_continuous_batching:
        server_args.append('--cont-batching')
    if context.server_embeddings:
        server_args.append('--embedding')
    if context.server_reranking:
        server_args.append('--reranking')
    if context.server_metrics:
        server_args.append('--metrics')
    if context.model_alias:
        server_args.extend(['--alias', context.model_alias])
    if context.n_ctx:
        server_args.extend(['--ctx-size', context.n_ctx])
    if context.n_slots:
        server_args.extend(['--parallel', context.n_slots])
    if context.n_server_predict:
        server_args.extend(['--n-predict', context.n_server_predict])
    if context.slot_save_path:
        server_args.extend(['--slot-save-path', context.slot_save_path])
    if context.server_api_key:
        server_args.extend(['--api-key', context.server_api_key])
    if context.n_ga:
        server_args.extend(['--grp-attn-n', context.n_ga])
    if context.n_ga_w:
        server_args.extend(['--grp-attn-w', context.n_ga_w])
    if context.debug:
        server_args.append('--verbose')
    if context.lora_file:
        server_args.extend(['--lora', context.lora_file])
    if context.disable_ctx_shift:
        server_args.extend(['--no-context-shift'])

    args = [str(arg) for arg in [context.server_path, *server_args]]
    print(f"bench: starting server with: {' '.join(args)}")

    flags = 0
    if 'nt' == os.name:
        flags |= subprocess.DETACHED_PROCESS
        flags |= subprocess.CREATE_NEW_PROCESS_GROUP
        flags |= subprocess.CREATE_NO_WINDOW

    pkwargs = {
        'creationflags': flags,
        'stdout': subprocess.PIPE,
        'stderr': subprocess.PIPE
    }
    context.server_process = subprocess.Popen(
        [str(arg) for arg in [context.server_path, *server_args]],
        **pkwargs)  # pyright: ignore[reportArgumentType, reportCallIssue]

    def server_log(in_stream, out_stream):
        for line in iter(in_stream.readline, b''):
            print(line.decode('utf-8'), end='', file=out_stream)

    thread_stdout = threading.Thread(target=server_log, args=(context.server_process.stdout, sys.stdout))
    thread_stdout.start()

    thread_stderr = threading.Thread(target=server_log, args=(context.server_process.stderr, sys.stderr))
    thread_stderr.start()

    print(f"server pid={context.server_process.pid}, behave pid={os.getpid()}")