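# behave step definitions for the llama.cpp server test suite: they configure a
# server on the behave context, start the server binary in the background and
# exercise its HTTP API (completions, OAI-compatible chat and embeddings,
# tokenization, CORS, health, slots and metrics endpoints).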
import asyncio
import collections
import json
import os
import re
import socket
import subprocess
import time
from contextlib import closing
from re import RegexFlag

import aiohttp
import numpy as np
import openai
from behave import step
from behave.api.async_step import async_run_until_complete
from huggingface_hub import hf_hub_download
from prometheus_client import parser

@step(u"a server listening on {server_fqdn}:{server_port}")
def step_server_config(context, server_fqdn, server_port):
    context.server_fqdn = server_fqdn
    context.server_port = int(server_port)
    if 'PORT' in os.environ:
        context.server_port = int(os.environ['PORT'])
        print(f"$PORT set, overriding server port to {context.server_port}")
    if 'FQDN' in os.environ:
        context.server_fqdn = os.environ['FQDN']
        print(f"$FQDN set, overriding server fqdn to {context.server_fqdn}")

    context.base_url = f'http://{context.server_fqdn}:{context.server_port}'

    context.model_alias = None
    context.n_batch = None
    context.n_ctx = None
    context.n_ga = None
    context.n_ga_w = None
    context.n_gpu_layer = None
    context.n_predict = None
    context.n_prompts = 0
    context.n_server_predict = None
    context.n_slots = None
    context.prompt_prefix = None
    context.prompt_suffix = None
    context.server_api_key = None
    context.server_continuous_batching = False
    context.server_embeddings = False
    context.server_metrics = False
    context.server_process = None
    context.seed = None
    context.server_seed = None
    context.user_api_key = None

    context.tasks_result = []
    context.concurrent_tasks = []
    context.prompts = []

@step(u'a model file {hf_file} from HF repo {hf_repo}')
def step_download_hf_model(context, hf_file, hf_repo):
    context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
    if context.debug:
        print(f"model file: {context.model_file}\n")


@step(u'a model alias {model_alias}')
def step_model_alias(context, model_alias):
    context.model_alias = model_alias


@step(u'{seed:d} as server seed')
def step_seed(context, seed):
    context.server_seed = seed


@step(u'{ngl:d} GPU offloaded layers')
def step_n_gpu_layer(context, ngl):
    if 'N_GPU_LAYERS' in os.environ:
        new_ngl = int(os.environ['N_GPU_LAYERS'])
        if context.debug:
            print(f"-ngl upgraded from {ngl} to {new_ngl}")
        ngl = new_ngl
    context.n_gpu_layer = ngl


@step(u'{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx):
    context.n_ctx = n_ctx


@step(u'{n_slots:d} slots')
def step_n_slots(context, n_slots):
    context.n_slots = n_slots


@step(u'{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict):
    context.n_server_predict = n_predict


@step(u'continuous batching')
def step_server_continuous_batching(context):
    context.server_continuous_batching = True


@step(u'embeddings extraction')
def step_server_embeddings(context):
    context.server_embeddings = True


@step(u'prometheus compatible metrics exposed')
def step_server_metrics(context):
    context.server_metrics = True

@step(u"the server is starting")
def step_start_server(context):
    start_server_background(context)
    attempts = 0
    while True:
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            result = sock.connect_ex((context.server_fqdn, context.server_port))
            if result == 0:
                print("\x1b[33;46mserver started!\x1b[0m")
                return
            attempts += 1
            if attempts > 20:
                assert False, "server not started"
            print(f"waiting for server to start, connect error code = {result}...")
            time.sleep(0.1)

@step(u"the server is {expecting_status}")
@async_run_until_complete
async def step_wait_for_the_server_to_be_started(context, expecting_status):
    match expecting_status:
        case 'healthy':
            await wait_for_health_status(context, context.base_url, 200, 'ok')

        case 'ready' | 'idle':
            await wait_for_health_status(context, context.base_url, 200, 'ok',
                                         timeout=10,
                                         params={'fail_on_no_slot': 0, 'include_slots': 0},
                                         slots_idle=context.n_slots,
                                         slots_processing=0,
                                         expected_slots=[{'id': slot_id, 'state': 0}
                                                         for slot_id in
                                                         range(context.n_slots if context.n_slots else 1)])
        case 'busy':
            await wait_for_health_status(context, context.base_url, 503,
                                         'no slot available',
                                         params={'fail_on_no_slot': 0, 'include_slots': 0},
                                         slots_idle=0,
                                         slots_processing=context.n_slots,
                                         expected_slots=[{'id': slot_id, 'state': 1}
                                                         for slot_id in
                                                         range(context.n_slots if context.n_slots else 1)])
        case _:
            assert False, "unknown status"

@step(u'all slots are {expected_slot_status_string}')
@async_run_until_complete
async def step_all_slots_status(context, expected_slot_status_string):
    match expected_slot_status_string:
        case 'idle':
            expected_slot_status = 0
        case 'busy':
            expected_slot_status = 1
        case _:
            assert False, "unknown status"

    expected_slots = [{'id': slot_id, 'state': expected_slot_status}
                      for slot_id in range(context.n_slots)]
    await request_slots_status(context, expected_slots)

@step(u'a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error):
    expect_api_error = api_error == 'raised'
    completion = await request_completion(context.prompts.pop(),
                                          context.base_url,
                                          debug=context.debug,
                                          n_predict=context.n_predict,
                                          seed=await completions_seed(context),
                                          expect_api_error=expect_api_error,
                                          user_api_key=context.user_api_key)
    context.tasks_result.append(completion)
    if context.debug:
        print(f"Completion response: {completion}\n")
    if expect_api_error:
        assert completion == 401, f"completion must be a 401 status code: {completion}"

@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)


@step(u'{predicted_n:d} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n):
    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)


@step(u'a user prompt {user_prompt}')
def step_user_prompt(context, user_prompt):
    context.prompts.append(user_prompt)
    context.n_prompts = len(context.prompts)


@step(u'a system prompt {system_prompt}')
def step_system_prompt(context, system_prompt):
    context.system_prompt = system_prompt


@step(u'a model {model}')
def step_model(context, model):
    context.model = model


@step(u'{max_tokens:d} max tokens to predict')
def step_max_tokens(context, max_tokens):
    context.n_predict = max_tokens


@step(u'streaming is {enable_streaming}')
def step_streaming(context, enable_streaming):
    context.enable_streaming = enable_streaming == 'enabled'


@step(u'a user api key {user_api_key}')
def step_user_api_key(context, user_api_key):
    context.user_api_key = user_api_key


@step(u'no user api key')
def step_no_user_api_key(context):
    context.user_api_key = None


@step(u'a user api key ')
def step_no_user_api_key_space(context):
    context.user_api_key = None


@step(u'a server api key {server_api_key}')
def step_server_api_key(context, server_api_key):
    context.server_api_key = server_api_key


@step(u'{n_junk:d} as number of junk')
def step_n_junk(context, n_junk):
    context.n_junk = n_junk


@step(u'{n_batch:d} as batch size')
def step_n_batch(context, n_batch):
    context.n_batch = n_batch


@step(u'{seed:d} as seed')
def step_seed(context, seed):
    context.seed = seed


@step(u'a prefix prompt')
def step_prompt_prefix(context):
    context.prompt_prefix = context.text


@step(u'a junk suffix prompt')
def step_prompt_junk_suffix(context):
    context.prompt_junk_suffix = context.text


@step(u'a suffix prompt')
def step_prompt_suffix(context):
    context.prompt_suffix = context.text

@step(u'{n_ga:d} group attention factor'
      u' to extend context size through self-extend')
def step_impl(context, n_ga):
    context.n_ga = n_ga


@step(u'{n_ga_w:d} group attention width to extend context size through self-extend')
def step_impl(context, n_ga_w):
    context.n_ga_w = n_ga_w


@step(u'a passkey prompt template')
def step_prompt_passkey(context):
    context.prompt_passkey = context.text


@step(u'{n_prompts:d} fixed prompts')
def step_fixed_prompts(context, n_prompts):
    context.prompts.extend([str(0) * (context.n_batch if context.n_batch is not None else 512) for i in range(n_prompts)])
    context.n_prompts = n_prompts

@step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
def step_prompt_passkey(context, passkey, i_pos):
    prompt = ""
    for i in range(context.n_junk):
        if i % context.n_junk == i_pos:
            prompt += context.prompt_passkey  # the passkey is already substituted
        prompt += context.prompt_junk_suffix
    if context.debug:
        passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
    context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
    context.n_prompts = len(context.prompts)

@step(u'an OAI compatible chat completions request with {api_error} api error')
@async_run_until_complete
async def step_oai_chat_completions(context, api_error):
    if context.debug:
        print("Submitting OAI compatible completions request...\n")
    expect_api_error = api_error == 'raised'
    completion = await oai_chat_completions(context.prompts.pop(),
                                            context.system_prompt,
                                            context.base_url,
                                            '/v1/chat',
                                            False,
                                            model=context.model if hasattr(context, 'model') else None,
                                            n_predict=context.n_predict
                                            if hasattr(context, 'n_predict') else None,
                                            enable_streaming=context.enable_streaming
                                            if hasattr(context, 'enable_streaming') else None,
                                            seed=await completions_seed(context),
                                            user_api_key=context.user_api_key
                                            if hasattr(context, 'user_api_key') else None,
                                            expect_api_error=expect_api_error)
    context.tasks_result.append(completion)
    if context.debug:
        print(f"Completion response: {completion}")
    if expect_api_error:
        assert completion == 401, f"completion must be a 401 status code: {completion}"

@step(u'a prompt')
def step_a_prompt(context):
    context.prompts.append(context.text)
    context.n_prompts = len(context.prompts)


@step(u'a prompt {prompt}')
def step_a_prompt_prompt(context, prompt):
    context.prompts.append(prompt)
    context.n_prompts = len(context.prompts)


@step(u'concurrent completion requests')
@async_run_until_complete()
async def step_concurrent_completion_requests(context):
    await concurrent_requests(context,
                              request_completion,
                              # prompt is inserted automatically
                              context.base_url,
                              debug=context.debug,
                              prompt_prefix=context.prompt_prefix,
                              prompt_suffix=context.prompt_suffix,
                              n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
                              seed=await completions_seed(context),
                              user_api_key=context.user_api_key if hasattr(context,
                                                                           'user_api_key') else None)

@step(u'concurrent OAI completions requests')
@async_run_until_complete
async def step_oai_chat_completions(context):
    await concurrent_requests(context, oai_chat_completions,
                              # user_prompt is inserted automatically
                              context.system_prompt,
                              context.base_url,
                              '/v1/chat/completions',
                              True,  # async_client
                              model=context.model
                              if hasattr(context, 'model') else None,
                              n_predict=context.n_predict
                              if hasattr(context, 'n_predict') else None,
                              enable_streaming=context.enable_streaming
                              if hasattr(context, 'enable_streaming') else None,
                              seed=await completions_seed(context),
                              user_api_key=context.user_api_key
                              if hasattr(context, 'user_api_key') else None)


@step(u'concurrent OAI completions requests no v1')
@async_run_until_complete
async def step_oai_chat_completions(context):
    await concurrent_requests(context, oai_chat_completions,
                              # user_prompt is inserted automatically
                              context.system_prompt,
                              context.base_url,
                              '/chat/completions',
                              True,  # async_client
                              model=context.model
                              if hasattr(context, 'model') else None,
                              n_predict=context.n_predict
                              if hasattr(context, 'n_predict') else None,
                              enable_streaming=context.enable_streaming
                              if hasattr(context, 'enable_streaming') else None,
                              seed=context.seed
                              if hasattr(context, 'seed') else
                              context.server_seed
                              if hasattr(context, 'server_seed') else None,
                              user_api_key=context.user_api_key
                              if hasattr(context, 'user_api_key') else None)

@step(u'all prompts are predicted')
@async_run_until_complete
async def step_all_prompts_are_predicted(context):
    await all_prompts_are_predicted(context)


@step(u'all prompts are predicted with {n_expected_predicted:d} tokens')
@async_run_until_complete
async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted):
    await all_prompts_are_predicted(context, n_expected_predicted)


async def all_prompts_are_predicted(context, expected_predicted_n=None):
    n_completions = await gather_tasks_results(context)
    assert n_completions > 0
    for i in range(n_completions):
        assert_n_tokens_predicted(context.tasks_result.pop(), expected_predicted_n=expected_predicted_n)
    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"

@step(u'embeddings are computed for')
@async_run_until_complete
async def step_compute_embedding(context):
    context.n_prompts = 1
    context.embeddings = await request_embedding(context.text, base_url=context.base_url)


@step(u'all embeddings are the same')
@async_run_until_complete
async def step_all_embeddings_are_the_same(context):
    n_embedding_requests = await gather_tasks_results(context)
    assert n_embedding_requests > 0
    embeddings = []
    for i in range(n_embedding_requests):
        embedding = context.tasks_result.pop().pop()
        embeddings.append(embedding)
        assert_embeddings(embedding)
    n = len(embeddings)
    for i in range(n - 1):
        for j in range(i + 1, n):
            embedding1 = np.array(embeddings[i])
            embedding2 = np.array(embeddings[j])
            if context.debug:
                print(f"embedding1: {embedding1[-8:]}\n")
                print(f"embedding2: {embedding2[-8:]}\n")
            similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
            msg = f"Similarity between {i} and {j}: {similarity:.10f}"
            if context.debug:
                print(f"{msg}\n")
            assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg

@step(u'embeddings are generated')
def step_assert_embeddings(context):
    assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n"
                                                          f"context.n_prompts={context.n_prompts}\n"
                                                          f"context.embeddings={context.embeddings}")
    for embedding in context.embeddings:
        assert_embeddings(embedding)


@step(u'an OAI compatible embeddings computation request for')
@async_run_until_complete
async def step_oai_compute_embeddings(context):
    context.n_prompts = 1
    context.embeddings = await request_oai_embeddings(context.text,
                                                      base_url=context.base_url,
                                                      user_api_key=context.user_api_key,
                                                      model=context.model)


@step(u'an OAI compatible embeddings computation request for multiple inputs')
@async_run_until_complete
async def step_oai_compute_embeddings_multiple_inputs(context):
    context.embeddings = await request_oai_embeddings(context.prompts,
                                                      base_url=context.base_url,
                                                      user_api_key=context.user_api_key,
                                                      model=context.model)
    context.prompts.clear()

@step(u'concurrent embedding requests')
@async_run_until_complete()
async def step_concurrent_embedding_requests(context):
    await concurrent_requests(context,
                              request_embedding,
                              # prompt is inserted automatically
                              base_url=context.base_url)


@step(u'concurrent OAI embedding requests')
@async_run_until_complete()
async def step_concurrent_oai_embedding_requests(context):
    await concurrent_requests(context,
                              request_oai_embeddings,
                              # prompt is inserted automatically
                              base_url=context.base_url,
                              async_client=True,
                              model=context.model)


@step(u'all embeddings are generated')
@async_run_until_complete()
async def all_embeddings_are_generated(context):
    n_embedding_requests = await gather_tasks_results(context)
    assert n_embedding_requests == context.n_prompts
    for i in range(n_embedding_requests):
        assert_embeddings(context.tasks_result.pop().pop())

@step(u'tokenizing')
@async_run_until_complete
async def step_tokenize(context):
    context.tokenized_text = context.text
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{context.base_url}/tokenize',
                                json={
                                    "content": context.tokenized_text,
                                }) as response:
            assert response.status == 200
            tokenize_json = await response.json()
            context.tokens = tokenize_json['tokens']


@step(u'tokens can be detokenize')
@async_run_until_complete
async def step_detokenize(context):
    assert len(context.tokens) > 0
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{context.base_url}/detokenize',
                                json={
                                    "tokens": context.tokens,
                                }) as response:
            assert response.status == 200
            detokenize_json = await response.json()
            # SPM tokenizer adds a whitespace prefix: https://github.com/google/sentencepiece/issues/15
            assert context.tokenized_text == detokenize_json['content'].strip()

@step(u'an OPTIONS request is sent from {origin}')
@async_run_until_complete
async def step_options_request(context, origin):
    async with aiohttp.ClientSession() as session:
        async with session.options(f'{context.base_url}/v1/chat/completions',
                                   headers={"Origin": origin}) as response:
            assert response.status == 200
            context.options_response = response


@step(u'CORS header {cors_header} is set to {cors_header_value}')
def step_check_options_header_value(context, cors_header, cors_header_value):
    assert context.options_response.headers[cors_header] == cors_header_value

@step(u'prometheus metrics are exposed')
@async_run_until_complete
async def step_prometheus_metrics_exported(context):
    async with aiohttp.ClientSession() as session:
        async with await session.get(f'{context.base_url}/metrics') as metrics_response:
            assert metrics_response.status == 200
            assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4"
            metrics_raw = await metrics_response.text()
            metric_exported = False
            if context.debug:
                print(f"/metrics answer:\n{metrics_raw}\n")
            for metric in parser.text_string_to_metric_families(metrics_raw):
                match metric.name:
                    case "llamacpp:kv_cache_usage_ratio":
                        assert len(metric.samples) > 0
                        metric_exported = True
            assert metric_exported, "No metrics exported"

@step(u'available models')
def step_available_models(context):
    # openai client always expects an api_key
    openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
    openai.api_base = f'{context.base_url}/v1'
    context.models = openai.Model.list().data


@step(u'{n_model:d} models are supported')
def step_supported_models(context, n_model):
    if context.debug:
        print("server models available:", context.models)
    assert len(context.models) == n_model


@step(u'model {i_model:d} is {param} {preposition} {param_value}')
def step_supported_models(context, i_model, param, preposition, param_value):
    assert i_model < len(context.models)
    model = context.models[i_model]

    param_value = param_value.split(' ', 1)[0]
    match param:
        case 'identified':
            value = model.id
        case 'trained':
            value = str(model.meta.n_ctx_train)
        case _:
            assert False, f"param {param} not supported"
    assert param_value == value, f"model param {param} {value} != {param_value}"

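# Fire one task per queued prompt without awaiting it; results are collected
# later by gather_tasks_results().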
async def concurrent_requests(context, f_completion, *args, **kwargs):
    context.n_prompts = len(context.prompts)
    if context.debug:
        print(f"starting {context.n_prompts} concurrent completion requests...")
    assert context.n_prompts > 0
    for prompt_no in range(context.n_prompts):
        shifted_args = [context.prompts.pop(), *args]
        context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
    await asyncio.sleep(0.1)

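# POST a raw /completion request; return the JSON body on success or the HTTP
# status code when an API error is expected.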
async def request_completion(prompt,
                             base_url,
                             debug=False,
                             prompt_prefix=None,
                             prompt_suffix=None,
                             n_predict=None,
                             seed=None,
                             expect_api_error=None,
                             user_api_key=None):
    if debug:
        print(f"Sending completion request: {prompt}")
    origin = "my.super.domain"
    headers = {
        'Origin': origin
    }
    if user_api_key is not None:
        if debug:
            print(f"Set user_api_key: {user_api_key}")
        headers['Authorization'] = f'Bearer {user_api_key}'

    async with aiohttp.ClientSession() as session:
        async with session.post(f'{base_url}/completion',
                                json={
                                    "input_prefix": prompt_prefix,
                                    "prompt": prompt,
                                    "input_suffix": prompt_suffix,
                                    "n_predict": n_predict if n_predict is not None else -1,
                                    "seed": seed if seed is not None else 42
                                },
                                headers=headers,
                                timeout=3600) as response:
            if expect_api_error is None or not expect_api_error:
                assert response.status == 200
                assert response.headers['Access-Control-Allow-Origin'] == origin
                return await response.json()
            else:
                return response.status

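# Send an OpenAI-compatible chat completion request, either through aiohttp
# (async_client=True, streaming or not) or through the openai client library,
# and normalize the answer to llama.cpp's {'content', 'timings'} shape.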
async def oai_chat_completions(user_prompt,
                               system_prompt,
                               base_url,
                               base_path,
                               async_client,
                               debug=False,
                               model=None,
                               n_predict=None,
                               enable_streaming=None,
                               seed=None,
                               user_api_key=None,
                               expect_api_error=None):
    if debug:
        print(f"Sending OAI Chat completions request: {user_prompt}")
    # openai client always expects an api key
    user_api_key = user_api_key if user_api_key is not None else 'nope'
    seed = seed if seed is not None else 42
    enable_streaming = enable_streaming if enable_streaming is not None else False
    payload = {
        "messages": [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": user_prompt,
            }
        ],
        "model": model,
        "max_tokens": n_predict,
        "stream": enable_streaming,
        "seed": seed
    }
    completion_response = {
        'content': '',
        'timings': {
            'predicted_n': 0
        }
    }
    if async_client:
        origin = 'llama.cpp'
        headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
        async with aiohttp.ClientSession() as session:
            async with session.post(f'{base_url}{base_path}',
                                    json=payload,
                                    headers=headers) as response:
                if enable_streaming:
                    assert response.status == 200
                    assert response.headers['Access-Control-Allow-Origin'] == origin
                    assert response.headers['Content-Type'] == "text/event-stream"
                    event_received = True
                    while event_received:
                        event_received = False
                        async for line_in_bytes in response.content:
                            line = line_in_bytes.decode('utf8')
                            line = line.rstrip('\n').rstrip('\r')
                            if line == '':
                                continue
                            event_data = line.split(': ', 1)
                            assert event_data[0] == 'data', f'Bad event code received: ```{event_data}```'
                            chunk_raw = event_data[1]

                            chunk = json.loads(chunk_raw)
                            assert len(chunk['choices']) == 1, f"no choices provided, line ```{line}```"
                            delta = chunk['choices'][0]['delta']
                            if 'content' in delta:
                                completion_response['content'] += delta['content']
                                completion_response['timings']['predicted_n'] += 1
                else:
                    if expect_api_error is None or not expect_api_error:
                        assert response.status == 200
                        assert response.headers['Access-Control-Allow-Origin'] == origin
                        assert response.headers['Content-Type'] == "application/json; charset=utf-8"
                        chat_completion_raw = await response.json()
                        completion_response = {
                            'content': chat_completion_raw['choices'][0]['message'],
                            'timings': {
                                'predicted_n': chat_completion_raw['usage']['completion_tokens']
                            }
                        }
                    else:
                        return response.status
    else:
        try:
            openai.api_key = user_api_key
            openai.api_base = f'{base_url}{base_path}'
            chat_completion = openai.Completion.create(
                messages=payload['messages'],
                model=model,
                max_tokens=n_predict,
                stream=enable_streaming,
                seed=seed
            )
        except openai.error.APIError as e:
            if expect_api_error is not None and expect_api_error:
                return 401
            else:
                assert False, f'error raised: {e}'

        if enable_streaming:
            for chunk in chat_completion:
                assert len(chunk.choices) == 1
                delta = chunk.choices[0].delta
                if 'content' in delta:
                    completion_response['content'] += delta['content']
                    completion_response['timings']['predicted_n'] += 1
        else:
            assert len(chat_completion.choices) == 1
            completion_response = {
                'content': chat_completion.choices[0].message.content,
                'timings': {
                    'predicted_n': chat_completion.usage.completion_tokens
                }
            }
    if debug:
        print("OAI response formatted to llama.cpp:", completion_response)
    return completion_response

async def request_embedding(content, base_url=None):
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{base_url}/embedding',
                                json={
                                    "content": content,
                                }) as response:
            assert response.status == 200
            response_json = await response.json()
            return [response_json['embedding']]

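# Request embeddings through the OpenAI-compatible /v1/embeddings endpoint, via
# aiohttp when async_client is set, otherwise via the openai client library.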
async def request_oai_embeddings(input,
                                 base_url=None, user_api_key=None,
                                 model=None, async_client=False):
    # openai client always expects an api_key
    user_api_key = user_api_key if user_api_key is not None else 'nope'
    if async_client:
        origin = 'llama.cpp'
        headers = {}
        if user_api_key is not None:
            headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
        async with aiohttp.ClientSession() as session:
            async with session.post(f'{base_url}/v1/embeddings',
                                    json={
                                        "input": input,
                                        "model": model,
                                    },
                                    headers=headers,
                                    timeout=3600) as response:
                assert response.status == 200, f"received status code not expected: {response.status}"
                assert response.headers['Access-Control-Allow-Origin'] == origin
                assert response.headers['Content-Type'] == "application/json; charset=utf-8"
                response_json = await response.json()
                assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
                assert response_json['object'] == 'list'
                if isinstance(input, collections.abc.Sequence):
                    embeddings = []
                    for an_oai_embeddings in response_json['data']:
                        embeddings.append(an_oai_embeddings['embedding'])
                else:
                    embeddings = [response_json['data']['embedding']]
                return embeddings
    else:
        openai.api_key = user_api_key
        openai.api_base = f'{base_url}/v1'
        oai_embeddings = openai.Embedding.create(
            model=model,
            input=input,
        )

        if isinstance(input, collections.abc.Sequence):
            embeddings = []
            for an_oai_embeddings in oai_embeddings.data:
                embeddings.append(an_oai_embeddings.embedding)
        else:
            embeddings = [oai_embeddings.data.embedding]
        return embeddings

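# Check a normalized completion response: the content must be non-empty,
# optionally match re_content, and predicted_n must equal the expected count.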
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
    content = completion_response['content']
    n_predicted = completion_response['timings']['predicted_n']
    assert len(content) > 0, "no token predicted"
    if re_content is not None:
        p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
        matches = p.finditer(content)
        last_match = 0
        highlighted = ''
        for match in matches:
            start, end = match.span()
            highlighted += content[last_match: start]
            highlighted += '\x1b[33m'
            highlighted += content[start: end]
            highlighted += '\x1b[0m'
            last_match = end
        highlighted += content[last_match:]
        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
            print(f"Checking completion response: {highlighted}\n")
        assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
    if expected_predicted_n and expected_predicted_n > 0:
        assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                     f' {n_predicted} <> {expected_predicted_n}')

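# Await every pending concurrent task and move its result into context.tasks_result.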
async def gather_tasks_results(context):
    n_tasks = len(context.concurrent_tasks)
    if context.debug:
        print(f"Waiting for all {n_tasks} tasks results...\n")
    for task_no in range(n_tasks):
        context.tasks_result.append(await context.concurrent_tasks.pop())
    n_completions = len(context.tasks_result)
    return n_completions

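# Poll /health until it reports the expected HTTP status, health status and slot
# counters, or until the timeout expires.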
async def wait_for_health_status(context,
                                 base_url,
                                 expected_http_status_code,
                                 expected_health_status,
                                 timeout=3,
                                 params=None,
                                 slots_idle=None,
                                 slots_processing=None,
                                 expected_slots=None):
    if context.debug:
        print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
    interval = 0.5
    counter = 0
    async with aiohttp.ClientSession() as session:
        while True:
            async with await session.get(f'{base_url}/health', params=params) as health_response:
                status_code = health_response.status
                health = await health_response.json()
                if context.debug:
                    print(f"HEALTH - response for expected health status='{expected_health_status}' on "
                          f"'{base_url}/health'?{params} is {health}\n")
                if (status_code == expected_http_status_code
                        and health['status'] == expected_health_status
                        and (slots_idle is None or health['slots_idle'] == slots_idle)
                        and (slots_processing is None or health['slots_processing'] == slots_processing)):
                    if expected_slots is not None:
                        assert_slots_status(health['slots'], expected_slots)
                    return
            await asyncio.sleep(interval)

            counter += interval
            if counter >= timeout:
                # Sometimes health requests are triggered after completions are predicted
                if expected_http_status_code == 503:
                    if len(context.tasks_result) == 0:
                        print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
                              " busy health check missed, probably too fast inference\x1b[0m\n")
                        n_completions = await gather_tasks_results(context)
                        if n_completions > 0:
                            return

                assert False, f'{expected_health_status} timeout exceeded {counter}s>={timeout}'

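# An embedding is valid if it is a non-empty list of floats with at least one
# non-zero component.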
def assert_embeddings(embeddings):
    assert len(embeddings) > 0
    embeddings_computed = False
    for emb in embeddings:
        if not isinstance(emb, float):
            assert False, f"Bad embeddings: {embeddings}"
        if emb != 0:
            embeddings_computed = True
    assert embeddings_computed, f"Embeddings: {embeddings}"

async def request_slots_status(context, expected_slots):
    async with aiohttp.ClientSession() as session:
        async with await session.get(f'{context.base_url}/slots') as slots_response:
            assert slots_response.status == 200
            slots = await slots_response.json()
            assert_slots_status(slots, expected_slots)


def assert_slots_status(slots, expected_slots):
    assert len(slots) == len(expected_slots)
    for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)):
        for key in expected:
            assert expected[key] == slot[key], (f"invalid slot {slot_id}"
                                                f" expected[{key}] != slot[{key}]"
                                                f" = {expected[key]} != {slot[key]}")

async def completions_seed(context):
    return context.seed if hasattr(context, 'seed') and context.seed is not None \
        else context.server_seed if hasattr(context, 'server_seed') else None

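# Launch the server binary in the background with the flags accumulated on the
# behave context; the binary path can be overridden with LLAMA_SERVER_BIN_PATH.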
def start_server_background(context):
    context.server_path = '../../../build/bin/server'
    if 'LLAMA_SERVER_BIN_PATH' in os.environ:
        context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
    server_args = [
        '--host', context.server_fqdn,
        '--port', context.server_port,
        '--model', context.model_file
    ]
    if context.n_batch:
        server_args.extend(['--batch-size', context.n_batch])
    if context.n_gpu_layer:
        server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
    if context.server_continuous_batching:
        server_args.append('--cont-batching')
    if context.server_embeddings:
        server_args.append('--embedding')
    if context.server_metrics:
        server_args.append('--metrics')
    if context.model_alias:
        server_args.extend(['--alias', context.model_alias])
    if context.n_ctx:
        server_args.extend(['--ctx-size', context.n_ctx])
    if context.n_slots:
        server_args.extend(['--parallel', context.n_slots])
    if context.n_server_predict:
        server_args.extend(['--n-predict', context.n_server_predict])
    if context.server_api_key:
        server_args.extend(['--api-key', context.server_api_key])
    if context.n_ga:
        server_args.extend(['--grp-attn-n', context.n_ga])
    if context.n_ga_w:
        server_args.extend(['--grp-attn-w', context.n_ga_w])
    if context.debug:
        server_args.append('--verbose')
    if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
        server_args.extend(['--log-format', "text"])
    print(f"starting server with: {context.server_path} {server_args}\n")
    context.server_process = subprocess.Popen(
        [str(arg) for arg in [context.server_path, *server_args]],
        close_fds=True)
    print(f"server pid={context.server_process.pid}")