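"""Behave step definitions for the llama.cpp server test suite.

These steps configure a `server` instance, start it in the background and
exercise it over HTTP: /completion, the OAI-compatible /v1 endpoints,
/tokenize, /detokenize, /embedding, /health, /slots and CORS preflight.
"""
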
import asyncio
import json
import os
import re
import socket
import subprocess
import time
from contextlib import closing
from re import RegexFlag

import aiohttp
import openai
from behave import step
from behave.api.async_step import async_run_until_complete

@step(u"a server listening on {server_fqdn}:{server_port}")
def step_server_config(context, server_fqdn, server_port):
    context.server_fqdn = server_fqdn
    context.server_port = int(server_port)
    if 'PORT' in os.environ:
        context.server_port = int(os.environ['PORT'])
        print(f"$PORT set, overriding server port to {context.server_port}")

    context.base_url = f'http://{context.server_fqdn}:{context.server_port}'

    context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
    context.model_alias = None
    context.n_ctx = None
    context.n_predict = None
    context.n_server_predict = None
    context.n_slots = None
    context.server_api_key = None
    context.server_continuous_batching = False
    context.server_embeddings = False
    context.server_seed = None
    context.user_api_key = None

    context.tasks_result = []
    context.concurrent_tasks = []
    context.prompts = []

@step(u'a model file {model_file}')
def step_model_file(context, model_file):
    context.model_file = model_file


@step(u'a model alias {model_alias}')
def step_model_alias(context, model_alias):
    context.model_alias = model_alias


@step(u'{seed} as server seed')
def step_seed(context, seed):
    context.server_seed = int(seed)


@step(u'{n_ctx} KV cache size')
def step_n_ctx(context, n_ctx):
    context.n_ctx = int(n_ctx)


@step(u'{n_slots} slots')
def step_n_slots(context, n_slots):
    context.n_slots = int(n_slots)


@step(u'{n_predict} server max tokens to predict')
def step_server_n_predict(context, n_predict):
    context.n_server_predict = int(n_predict)


@step(u'continuous batching')
def step_server_continuous_batching(context):
    context.server_continuous_batching = True


@step(u'embeddings extraction')
def step_server_embeddings(context):
    context.server_embeddings = True
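
# The configuration steps above are designed to be chained from a feature
# file before the server is started; a hypothetical scenario (model file and
# values are illustrative only) could read:
#
#   Background:
#     Given a server listening on localhost:8080
#     And a model file stories260K.gguf
#     And 42 as server seed
#     And 2 slots
#     And continuous batching
#     When the server is starting
#     Then the server is healthy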

@step(u"the server is starting")
def step_start_server(context):
    start_server_background(context)
    attempts = 0
    while True:
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            result = sock.connect_ex((context.server_fqdn, context.server_port))
            if result == 0:
                print("\x1b[33;46mserver started!\x1b[0m")
                return
            attempts += 1
            if attempts > 20:
                assert False, "server not started"
            print(f"waiting for server to start, connect error code = {result}...")
            time.sleep(0.1)

@step(u"the server is {expecting_status}")
@async_run_until_complete
async def step_wait_for_the_server_to_be_started(context, expecting_status):
    match expecting_status:
        case 'healthy':
            await wait_for_health_status(context, context.base_url, 200, 'ok')

        case 'ready' | 'idle':
            await wait_for_health_status(context, context.base_url, 200, 'ok',
                                         params={'fail_on_no_slot': 0, 'include_slots': 0},
                                         slots_idle=context.n_slots,
                                         slots_processing=0,
                                         expected_slots=[{'id': slot_id, 'state': 0}
                                                         for slot_id in range(context.n_slots)])
        case 'busy':
            await wait_for_health_status(context, context.base_url, 503,
                                         'no slot available',
                                         params={'fail_on_no_slot': 0, 'include_slots': 0},
                                         slots_idle=0,
                                         slots_processing=context.n_slots,
                                         expected_slots=[{'id': slot_id, 'state': 1}
                                                         for slot_id in range(context.n_slots)])
        case _:
            assert False, "unknown status"

@step(u'all slots are {expected_slot_status_string}')
@async_run_until_complete
async def step_all_slots_status(context, expected_slot_status_string):
    match expected_slot_status_string:
        case 'idle':
            expected_slot_status = 0
        case 'busy':
            expected_slot_status = 1
        case _:
            assert False, "unknown status"

    expected_slots = [{'id': slot_id, 'state': expected_slot_status}
                      for slot_id in range(context.n_slots)]
    await request_slots_status(context, expected_slots)

@step(u'a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error):
    expect_api_error = api_error == 'raised'
    completion = await request_completion(context.prompts.pop(),
                                          context.base_url,
                                          debug=context.debug,
                                          n_predict=context.n_predict,
                                          server_seed=context.server_seed,
                                          expect_api_error=expect_api_error,
                                          user_api_key=context.user_api_key)
    context.tasks_result.append(completion)
    if context.debug:
        print(f"Completion response: {completion}")
    if expect_api_error:
        assert completion == 401, f"completion must be a 401 status code: {completion}"

@step(u'{predicted_n} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)


@step(u'{predicted_n} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n):
    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))


@step(u'a user prompt {user_prompt}')
def step_user_prompt(context, user_prompt):
    context.prompts.append(user_prompt)


@step(u'a system prompt {system_prompt}')
def step_system_prompt(context, system_prompt):
    context.system_prompt = system_prompt


@step(u'a model {model}')
def step_model(context, model):
    context.model = model


@step(u'{max_tokens} max tokens to predict')
def step_max_tokens(context, max_tokens):
    context.n_predict = int(max_tokens)


@step(u'streaming is {enable_streaming}')
def step_streaming(context, enable_streaming):
    context.enable_streaming = enable_streaming == 'enabled'


@step(u'a user api key {user_api_key}')
def step_user_api_key(context, user_api_key):
    context.user_api_key = user_api_key


@step(u'no user api key')
def step_no_user_api_key(context):
    context.user_api_key = None


@step(u'a user api key ')
def step_no_user_api_key_space(context):
    context.user_api_key = None


@step(u'a server api key {server_api_key}')
def step_server_api_key(context, server_api_key):
    context.server_api_key = server_api_key

@step(u'an OAI compatible chat completions request with {api_error} api error')
@async_run_until_complete
async def step_oai_chat_completions(context, api_error):
    if context.debug:
        print("Submitting OAI compatible completions request...")
    expect_api_error = api_error == 'raised'
    completion = await oai_chat_completions(context.prompts.pop(),
                                            context.system_prompt,
                                            context.base_url,
                                            False,  # async_client
                                            model=context.model if hasattr(context, 'model') else None,
                                            n_predict=context.n_predict
                                            if hasattr(context, 'n_predict') else None,
                                            enable_streaming=context.enable_streaming
                                            if hasattr(context, 'enable_streaming') else None,
                                            server_seed=context.server_seed
                                            if hasattr(context, 'server_seed') else None,
                                            user_api_key=context.user_api_key
                                            if hasattr(context, 'user_api_key') else None,
                                            expect_api_error=expect_api_error)
    context.tasks_result.append(completion)
    if context.debug:
        print(f"Completion response: {completion}")
    if expect_api_error:
        assert completion == 401, f"completion must be a 401 status code: {completion}"

@step(u'a prompt')
def step_a_prompt(context):
    context.prompts.append(context.text)


@step(u'a prompt {prompt}')
def step_a_prompt_prompt(context, prompt):
    context.prompts.append(prompt)

@step(u'concurrent completion requests')
@async_run_until_complete()
async def step_concurrent_completion_requests(context):
    await concurrent_completion_requests(context,
                                         request_completion,
                                         # prompt is inserted automatically
                                         context.base_url,
                                         debug=context.debug,
                                         n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
                                         server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
                                         user_api_key=context.user_api_key if hasattr(context,
                                                                                      'user_api_key') else None)

@step(u'concurrent OAI completions requests')
@async_run_until_complete
async def step_concurrent_oai_chat_completions(context):
    await concurrent_completion_requests(context, oai_chat_completions,
                                         # user_prompt is inserted automatically
                                         context.system_prompt,
                                         context.base_url,
                                         True,  # async_client
                                         model=context.model
                                         if hasattr(context, 'model') else None,
                                         n_predict=context.n_predict
                                         if hasattr(context, 'n_predict') else None,
                                         enable_streaming=context.enable_streaming
                                         if hasattr(context, 'enable_streaming') else None,
                                         server_seed=context.server_seed
                                         if hasattr(context, 'server_seed') else None,
                                         user_api_key=context.user_api_key
                                         if hasattr(context, 'user_api_key') else None)

@step(u'all prompts are predicted')
@async_run_until_complete
async def step_all_prompts_are_predicted(context):
    await all_prompts_are_predicted(context)


@step(u'all prompts are predicted with {n_predict} tokens')
@async_run_until_complete
async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
    expected_predicted_n = int(n_predict)
    await all_prompts_are_predicted(context, expected_predicted_n)


async def all_prompts_are_predicted(context, expected_predicted_n=None):
    n_completions = await gather_tasks_results(context)
    assert n_completions > 0
    for i in range(n_completions):
        assert_n_tokens_predicted(context.tasks_result.pop(), expected_predicted_n=expected_predicted_n)
    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"

@step(u'embeddings are computed for')
@async_run_until_complete
async def step_compute_embedding(context):
    content = context.text
    base_url = context.base_url
    context.embeddings = await request_embedding(content, base_url)


@step(u'embeddings are generated')
def step_assert_embeddings(context):
    assert_embeddings(context.embeddings)


@step(u'an OAI compatible embeddings computation request for')
def step_oai_compute_embedding(context):
    openai.api_key = 'nope'  # openai client always expects an api_key
    if context.user_api_key is not None:
        openai.api_key = context.user_api_key
    openai.api_base = f'{context.base_url}/v1'
    embeddings = openai.Embedding.create(
        model=context.model,
        input=context.text,
    )
    context.embeddings = embeddings

@step(u'concurrent embedding requests')
@async_run_until_complete()
async def step_concurrent_embedding_requests(context):
    await concurrent_completion_requests(context,
                                         request_embedding,
                                         # content is inserted automatically
                                         context.base_url)


@step(u'all embeddings are generated')
@async_run_until_complete()
async def all_embeddings_are_generated(context):
    n_embedding_requests = await gather_tasks_results(context)
    assert n_embedding_requests > 0
    for i in range(n_embedding_requests):
        assert_embeddings(context.tasks_result.pop())

@step(u'tokenizing')
@async_run_until_complete
async def step_tokenize(context):
    context.tokenized_text = context.text
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{context.base_url}/tokenize',
                                json={
                                    "content": context.tokenized_text,
                                }) as response:
            assert response.status == 200
            tokenize_json = await response.json()
            context.tokens = tokenize_json['tokens']


@step(u'tokens can be detokenize')
@async_run_until_complete
async def step_detokenize(context):
    assert len(context.tokens) > 0
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{context.base_url}/detokenize',
                                json={
                                    "tokens": context.tokens,
                                }) as response:
            assert response.status == 200
            detokenize_json = await response.json()
            # SPM tokenizer adds a whitespace prefix: https://github.com/google/sentencepiece/issues/15
            assert context.tokenized_text == detokenize_json['content'].strip()
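
# For reference, the two endpoints used above exchange plain JSON; an
# equivalent manual round trip (host, port and token ids are illustrative)
# would be:
#
#   curl http://localhost:8080/tokenize --data '{"content": "Hello world"}'
#   # -> {"tokens": [...]}
#   curl http://localhost:8080/detokenize --data '{"tokens": [...]}'
#   # -> {"content": " Hello world"}  (note the SPM leading whitespace)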

@step(u'an OPTIONS request is sent from {origin}')
@async_run_until_complete
async def step_options_request(context, origin):
    async with aiohttp.ClientSession() as session:
        async with session.options(f'{context.base_url}/v1/chat/completions',
                                   headers={"Origin": origin}) as response:
            assert response.status == 200
            context.options_response = response


@step(u'CORS header {cors_header} is set to {cors_header_value}')
def step_check_options_header_value(context, cors_header, cors_header_value):
    assert context.options_response.headers[cors_header] == cors_header_value

async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
    n_prompts = len(context.prompts)
    if context.debug:
        print(f"starting {n_prompts} concurrent completion requests...")
    assert n_prompts > 0
    for prompt_no in range(n_prompts):
        shifted_args = [context.prompts.pop(), *args]
        context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
    await asyncio.sleep(0.1)

async def request_completion(prompt,
                             base_url,
                             debug=False,
                             n_predict=None,
                             server_seed=None,
                             expect_api_error=None,
                             user_api_key=None):
    if debug:
        print(f"Sending completion request: {prompt}")
    origin = "my.super.domain"
    headers = {
        'Origin': origin
    }
    if user_api_key is not None:
        if debug:
            print(f"Set user_api_key: {user_api_key}")
        headers['Authorization'] = f'Bearer {user_api_key}'

    async with aiohttp.ClientSession() as session:
        async with session.post(f'{base_url}/completion',
                                json={
                                    "prompt": prompt,
                                    "n_predict": int(n_predict) if n_predict is not None else -1,
                                    "seed": server_seed if server_seed is not None else 42
                                },
                                headers=headers) as response:
            if expect_api_error is None or not expect_api_error:
                assert response.status == 200
                assert response.headers['Access-Control-Allow-Origin'] == origin
                return await response.json()
            else:
                return response.status

async def oai_chat_completions(user_prompt,
                               system_prompt,
                               base_url,
                               async_client,
                               debug=False,
                               model=None,
                               n_predict=None,
                               enable_streaming=None,
                               server_seed=None,
                               user_api_key=None,
                               expect_api_error=None):
    if debug:
        print(f"Sending OAI Chat completions request: {user_prompt}")
    # openai client always expects an api key
    user_api_key = user_api_key if user_api_key is not None else 'nope'
    seed = server_seed if server_seed is not None else 42
    enable_streaming = enable_streaming if enable_streaming is not None else False
    payload = {
        "messages": [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": user_prompt,
            }
        ],
        "model": model,
        "max_tokens": n_predict,
        "stream": enable_streaming,
        "seed": seed
    }
    completion_response = {
        'content': '',
        'timings': {
            'predicted_n': 0
        }
    }
    if async_client:
        origin = 'llama.cpp'
        headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
        async with aiohttp.ClientSession() as session:
            async with session.post(f'{base_url}/v1/chat/completions',
                                    json=payload,
                                    headers=headers) as response:
                if enable_streaming:
                    assert response.status == 200
                    assert response.headers['Access-Control-Allow-Origin'] == origin
                    assert response.headers['Content-Type'] == "text/event-stream"
                    event_received = True
                    while event_received:
                        event_received = False
                        async for line_in_bytes in response.content:
                            line = line_in_bytes.decode('utf8')
                            line = line.rstrip('\n').rstrip('\r')
                            if line == '':
                                continue
                            event_data = line.split(': ', 1)
                            assert event_data[0] == 'data', f'Bad event code received: ```{event_data}```'
                            chunk_raw = event_data[1]

                            chunk = json.loads(chunk_raw)
                            assert len(chunk['choices']) == 1, f"no choices provided, line ```{line}```"
                            delta = chunk['choices'][0]['delta']
                            if 'content' in delta:
                                completion_response['content'] += delta['content']
                                completion_response['timings']['predicted_n'] += 1
                else:
                    if expect_api_error is None or not expect_api_error:
                        assert response.status == 200
                        assert response.headers['Access-Control-Allow-Origin'] == origin
                        assert response.headers['Content-Type'] == "application/json; charset=utf-8"
                        chat_completion_raw = await response.json()
                        completion_response = {
                            'content': chat_completion_raw['choices'][0]['message']['content'],
                            'timings': {
                                'predicted_n': chat_completion_raw['usage']['completion_tokens']
                            }
                        }
                    else:
                        return response.status
    else:
        try:
            openai.api_key = user_api_key
            # pointing api_base at .../v1/chat makes the pre-1.0 openai client
            # POST to {base_url}/v1/chat/completions
            openai.api_base = f'{base_url}/v1/chat'
            chat_completion = openai.Completion.create(
                messages=payload['messages'],
                model=model,
                max_tokens=n_predict,
                stream=enable_streaming,
                seed=seed
            )
        except openai.error.APIError as e:
            if expect_api_error is not None and expect_api_error:
                return 401
            else:
                assert False, f'error raised: {e}'
        if enable_streaming:
            for chunk in chat_completion:
                assert len(chunk.choices) == 1
                delta = chunk.choices[0].delta
                if 'content' in delta:
                    completion_response['content'] += delta['content']
                    completion_response['timings']['predicted_n'] += 1
        else:
            assert len(chat_completion.choices) == 1
            completion_response = {
                'content': chat_completion.choices[0].message.content,
                'timings': {
                    'predicted_n': chat_completion.usage.completion_tokens
                }
            }
    if debug:
        print("OAI response formatted to llama.cpp:", completion_response)
    return completion_response

async def request_embedding(content, base_url):
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{base_url}/embedding',
                                json={
                                    "content": content,
                                }) as response:
            assert response.status == 200
            response_json = await response.json()
            return response_json['embedding']

def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
    content = completion_response['content']
    n_predicted = completion_response['timings']['predicted_n']
    assert len(content) > 0, "no token predicted"
    if expected_predicted_n is not None:
        assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                     f' {n_predicted} <> {expected_predicted_n}')
    if re_content is not None:
        re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
        assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
            f'invalid tokens predicted:'
            f' ```\n{content}\n``` does not match /{re_content}/')
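
# `<or>` in {re_content} is a step-level alternation shorthand: a feature line
# such as `32 tokens are predicted matching read<or>book` is checked against
# the pattern ^.*read|book.*$ with IGNORECASE, MULTILINE and DOTALL enabled.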

async def gather_tasks_results(context):
    n_tasks = len(context.concurrent_tasks)
    if context.debug:
        print(f"Waiting for all {n_tasks} tasks results...")
    for task_no in range(n_tasks):
        context.tasks_result.append(await context.concurrent_tasks.pop())
    n_completions = len(context.tasks_result)
    return n_completions

async def wait_for_health_status(context,
                                 base_url,
                                 expected_http_status_code,
                                 expected_health_status,
                                 params=None,
                                 slots_idle=None,
                                 slots_processing=None,
                                 expected_slots=None):
    if context.debug:
        print(f"Waiting for health status {expected_health_status}...")
    timeout = 3  # seconds
    interval = 0.5
    counter = 0
    async with aiohttp.ClientSession() as session:
        while True:
            async with await session.get(f'{base_url}/health', params=params) as health_response:
                status_code = health_response.status
                health = await health_response.json()
                if context.debug:
                    print(f"HEALTH - response for expected health status='{expected_health_status}' on "
                          f"'{base_url}/health'?{params} is {health}")
                if (status_code == expected_http_status_code
                        and health['status'] == expected_health_status
                        and (slots_idle is None or health['slots_idle'] == slots_idle)
                        and (slots_processing is None or health['slots_processing'] == slots_processing)):
                    if expected_slots is not None:
                        assert_slots_status(health['slots'], expected_slots)
                    return
            await asyncio.sleep(interval)

            counter += interval
            if counter >= timeout:
                # Sometimes health requests are triggered after completions are predicted
                if expected_http_status_code == 503:
                    if len(context.tasks_result) == 0:
                        print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
                              " busy health check missed, probably too fast inference\x1b[0m")
                        n_completions = await gather_tasks_results(context)
                        if n_completions > 0:
                            return
                assert False, 'timeout exceeded'

def assert_embeddings(embeddings):
    assert len(embeddings) > 0
    embeddings_computed = False
    for emb in embeddings:
        if emb != 0:
            embeddings_computed = True
    assert embeddings_computed, f"Embeddings: {embeddings}"

async def request_slots_status(context, expected_slots):
    async with aiohttp.ClientSession() as session:
        async with await session.get(f'{context.base_url}/slots') as slots_response:
            assert slots_response.status == 200
            slots = await slots_response.json()
            assert_slots_status(slots, expected_slots)

def assert_slots_status(slots, expected_slots):
    assert len(slots) == len(expected_slots)
    for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)):
        for key in expected:
            assert expected[key] == slot[key], (f"invalid slot {slot_id}"
                                                f" expected[{key}] != slot[{key}]"
                                                f" = {expected[key]} != {slot[key]}")

def start_server_background(context):
    context.server_path = '../../../build/bin/server'
    if 'LLAMA_SERVER_BIN_PATH' in os.environ:
        context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
    server_args = [
        '--host', context.server_fqdn,
        '--port', context.server_port,
        '--model', context.model_file
    ]
    if context.server_continuous_batching:
        server_args.append('--cont-batching')
    if context.server_embeddings:
        server_args.append('--embedding')
    if context.model_alias is not None:
        server_args.extend(['--alias', context.model_alias])
    if context.n_ctx is not None:
        server_args.extend(['--ctx-size', context.n_ctx])
    if context.n_slots is not None:
        server_args.extend(['--parallel', context.n_slots])
    if context.n_server_predict is not None:
        server_args.extend(['--n-predict', context.n_server_predict])
    if context.server_api_key is not None:
        server_args.extend(['--api-key', context.server_api_key])
    if context.debug:
        server_args.append('--verbose')
    print(f"starting server with: {context.server_path}", *server_args)
    context.server_process = subprocess.Popen(
        [str(arg) for arg in [context.server_path, *server_args]],
        close_fds=True)
    print(f"server pid={context.server_process.pid}")
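
# A typical invocation, assuming behave and the python dependencies are
# installed, the server binary has been built, and the feature files live in
# the current directory (paths and port are illustrative):
#
#   LLAMA_SERVER_BIN_PATH=../../../build/bin/server PORT=8080 DEBUG=ON behave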