@@ -66,7 +66,7 @@ def step_server_config(context, server_fqdn, server_port):
 def step_download_hf_model(context, hf_file, hf_repo):
     context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
     if context.debug:
-        print(f"model file: {context.model_file}\n")
+        print(f"model file: {context.model_file}")
 
 
 @step('a model file {model_file}')
@@ -137,9 +137,13 @@ def step_start_server(context):
     if 'GITHUB_ACTIONS' in os.environ:
         max_attempts *= 2
 
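+    # resolve the address up front so the connect probe below uses the same family (IPv4/IPv6) as the server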
+    addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM)
+    family, typ, proto, _, sockaddr = addrs[0]
+
     while True:
-        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
-            result = sock.connect_ex((context.server_fqdn, context.server_port))
+        with closing(socket.socket(family, typ, proto)) as sock:
+            result = sock.connect_ex(sockaddr)
             if result == 0:
                 print("\x1b[33;46mserver started!\x1b[0m")
                 return
@@ -209,7 +212,7 @@ async def step_request_completion(context, api_error):
                                           user_api_key=context.user_api_key)
     context.tasks_result.append(completion)
     if context.debug:
-        print(f"Completion response: {completion}\n")
+        print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
 
@@ -354,7 +357,7 @@ def step_prompt_passkey(context, passkey, i_pos):
     prompt += context.prompt_junk_suffix
     if context.debug:
         passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
-        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
+        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```")
     context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
     context.n_prompts = len(context.prompts)
 
@@ -363,7 +366,7 @@ def step_prompt_passkey(context, passkey, i_pos):
 @async_run_until_complete
 async def step_oai_chat_completions(context, api_error):
     if context.debug:
-        print(f"Submitting OAI compatible completions request...\n")
+        print("Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
@@ -508,12 +511,12 @@ async def step_all_embeddings_are_the_same(context):
             embedding1 = np.array(embeddings[i])
             embedding2 = np.array(embeddings[j])
             if context.debug:
-                print(f"embedding1: {embedding1[-8:]}\n")
-                print(f"embedding2: {embedding2[-8:]}\n")
+                print(f"embedding1: {embedding1[-8:]}")
+                print(f"embedding2: {embedding2[-8:]}")
             similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
             msg = f"Similarity between {i} and {j}: {similarity:.10f}"
             if context.debug:
-                print(f"{msg}\n")
+                print(msg)
             assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
 
 
@@ -630,7 +633,7 @@ async def step_prometheus_metrics_exported(context):
             metrics_raw = await metrics_response.text()
             metric_exported = False
             if context.debug:
-                print(f"/metrics answer:\n{metrics_raw}\n")
+                print(f"/metrics answer:\n{metrics_raw}")
             context.metrics = {}
             for metric in parser.text_string_to_metric_families(metrics_raw):
                 match metric.name:
@@ -932,7 +935,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
             last_match = end
         highlighted += content[last_match:]
         if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
-            print(f"Checking completion response: {highlighted}\n")
+            print(f"Checking completion response: {highlighted}")
         assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
     if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
@@ -942,7 +945,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
     if context.debug:
-        print(f"Waiting for all {n_tasks} tasks results...\n")
+        print(f"Waiting for all {n_tasks} task results...")
     for task_no in range(n_tasks):
         context.tasks_result.append(await context.concurrent_tasks.pop())
     n_completions = len(context.tasks_result)
@@ -959,7 +962,7 @@ async def wait_for_health_status(context,
                                  slots_processing=None,
                                  expected_slots=None):
     if context.debug:
-        print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
+        print(f"Starting health check: expected_health_status={expected_health_status}")
     interval = 0.5
     counter = 0
     if 'GITHUB_ACTIONS' in os.environ:
@@ -1048,8 +1051,6 @@ def start_server_background(context):
     if 'LLAMA_SERVER_BIN_PATH' in os.environ:
         context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
     server_listen_addr = context.server_fqdn
-    if os.name == 'nt':
-        server_listen_addr = '0.0.0.0'
     server_args = [
         '--host', server_listen_addr,
         '--port', context.server_port,
@@ -1088,7 +1089,7 @@ def start_server_background(context):
         server_args.append('--verbose')
     if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
         server_args.extend(['--log-format', "text"])
-    print(f"starting server with: {context.server_path} {server_args}\n")
+    print(f"starting server with: {context.server_path} {server_args}")
     flags = 0
     if 'nt' == os.name:
         flags |= subprocess.DETACHED_PROCESS