
server tests : more pythonic process management; fix bare `except:` (#6146)

* server tests : remove seemingly redundant newlines in print()

* server tests : use built-in subprocess features, not os.kill and psutil

* server tests : do not catch e.g. SystemExit; use print_exc

* server tests: handle TimeoutExpired exception

* server tests: fix connect on dual-stack systems

* server: tests: add new tokens regex on windows generated following new repeat penalties default changed in (#6127)

* server: tests: remove the hack on windows since now we get the good socket family

* server: tests: add new tokens regex following new repeat penalties default changed in (#6127)

* server: tests: add new tokens regex following new repeat penalties default changed in (#6127)

---------

Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>
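
A side note on the bare `except:` fix listed above: a bare `except:` catches `BaseException`, which includes `KeyboardInterrupt` and `SystemExit`, so teardown errors could silently swallow Ctrl-C. A minimal, hypothetical sketch of the preferred pattern (illustrative helper name, not taken verbatim from the test suite):

import sys
import traceback

def run_teardown(cleanup):
    # Hypothetical helper, for illustration only.
    try:
        cleanup()
    except Exception:
        # `except Exception:` lets KeyboardInterrupt/SystemExit propagate,
        # unlike a bare `except:`, which also catches BaseException subclasses.
        print("ignoring error in teardown:")
        # print_exc() prints the active exception's traceback, replacing the
        # manual sys.exception() + traceback.print_tb() combination.
        traceback.print_exc(file=sys.stdout)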
Jared Van Bortel 1 year ago
Parent
Commit
bd60d82d0c

+ 22 - 52
examples/server/tests/features/environment.py

@@ -5,15 +5,14 @@ import sys
 import time
 import traceback
 from contextlib import closing
-
-import psutil
+from subprocess import TimeoutExpired


 def before_scenario(context, scenario):
     context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
     if context.debug:
-        print("DEBUG=ON\n")
-    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n")
+        print("DEBUG=ON")
+    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
     port = 8080
     if 'PORT' in os.environ:
         port = int(os.environ['PORT'])
@@ -27,60 +26,40 @@ def after_scenario(context, scenario):
             return
         if scenario.status == "failed":
             if 'GITHUB_ACTIONS' in os.environ:
-                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n\n")
+                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n")
                 if os.path.isfile('llama.log'):
                     with closing(open('llama.log', 'r')) as f:
                         for line in f:
                             print(line)
             if not is_server_listening(context.server_fqdn, context.server_port):
-                print("\x1b[33;101mERROR: Server stopped listening\x1b[0m\n")
+                print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")

-        if not pid_exists(context.server_process.pid):
+        if context.server_process.poll() is not None:
             assert False, f"Server not running pid={context.server_process.pid} ..."

-        server_graceful_shutdown(context)
+        server_graceful_shutdown(context)  # SIGINT

-        # Wait few for socket to free up
-        time.sleep(0.05)
+        try:
+            context.server_process.wait(0.5)
+        except TimeoutExpired:
+            print(f"server still alive after 500ms, force-killing pid={context.server_process.pid} ...")
+            context.server_process.kill()  # SIGKILL
+            context.server_process.wait()

-        attempts = 0
-        while pid_exists(context.server_process.pid) or is_server_listening(context.server_fqdn, context.server_port):
-            server_kill(context)
+        while is_server_listening(context.server_fqdn, context.server_port):
             time.sleep(0.1)
-            attempts += 1
-            if attempts > 5:
-                server_kill_hard(context)
-    except:
-        exc = sys.exception()
-        print("error in after scenario: \n")
-        print(exc)
-        print("*** print_tb: \n")
-        traceback.print_tb(exc.__traceback__, file=sys.stdout)
+    except Exception:
+        print("ignoring error in after_scenario:")
+        traceback.print_exc(file=sys.stdout)


 def server_graceful_shutdown(context):
-    print(f"shutting down server pid={context.server_process.pid} ...\n")
+    print(f"shutting down server pid={context.server_process.pid} ...")
     if os.name == 'nt':
-        os.kill(context.server_process.pid, signal.CTRL_C_EVENT)
+        interrupt = signal.CTRL_C_EVENT
     else:
-        os.kill(context.server_process.pid, signal.SIGINT)
-
-
-def server_kill(context):
-    print(f"killing server pid={context.server_process.pid} ...\n")
-    context.server_process.kill()
-
-
-def server_kill_hard(context):
-    pid = context.server_process.pid
-    path = context.server_path
-
-    print(f"Server dangling exits, hard killing force {pid}={path}...\n")
-    try:
-        psutil.Process(pid).kill()
-    except psutil.NoSuchProcess:
-        return False
-    return True
+        interrupt = signal.SIGINT
+    context.server_process.send_signal(interrupt)


 def is_server_listening(server_fqdn, server_port):
@@ -88,14 +67,5 @@ def is_server_listening(server_fqdn, server_port):
         result = sock.connect_ex((server_fqdn, server_port))
         _is_server_listening = result == 0
         if _is_server_listening:
-            print(f"server is listening on {server_fqdn}:{server_port}...\n")
+            print(f"server is listening on {server_fqdn}:{server_port}...")
         return _is_server_listening
-
-
-def pid_exists(pid):
-    try:
-        psutil.Process(pid)
-    except psutil.NoSuchProcess:
-        return False
-    return True
-
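
The environment.py changes above amount to using `subprocess.Popen`'s own primitives (`poll`, `send_signal`, `wait`, `kill`) instead of `os.kill` plus `psutil`. A rough standalone sketch of that shutdown pattern, using a placeholder child command rather than the real server binary:

import os
import signal
import subprocess
from subprocess import TimeoutExpired

# Placeholder command; the tests start the llama.cpp server binary instead.
proc = subprocess.Popen(["sleep", "60"])

# poll() returns None while the child is still running (replaces the psutil pid check).
assert proc.poll() is None, f"process {proc.pid} exited early"

# Graceful shutdown: CTRL_C_EVENT on Windows, SIGINT elsewhere.
interrupt = signal.CTRL_C_EVENT if os.name == 'nt' else signal.SIGINT
proc.send_signal(interrupt)

try:
    proc.wait(0.5)   # give the child 500 ms to exit on its own
except TimeoutExpired:
    proc.kill()      # escalate (SIGKILL / TerminateProcess)
    proc.wait()      # reap the child so nothing is left behind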

+ 7 - 7
examples/server/tests/features/server.feature

@@ -35,9 +35,9 @@ Feature: llama.cpp server
     And   metric llamacpp:tokens_predicted is <n_predicted>

     Examples: Prompts
-      | prompt                                                                    | n_predict | re_content                    | n_prompt | n_predicted | truncated |
-      | I believe the meaning of life is                                          | 8         | (read\|going)+                | 18       | 8           | not       |
-      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids)+ | 46       | 64          | not       |
+      | prompt                                                                    | n_predict | re_content                                  | n_prompt | n_predicted | truncated |
+      | I believe the meaning of life is                                          | 8         | (read\|going)+                              | 18       | 8           | not       |
+      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids\|Anna\|forest)+ | 46       | 64          | not       |

   Scenario: Completion prompt truncated
     Given a prompt:
@@ -48,7 +48,7 @@ Feature: llama.cpp server
     Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
     """
     And   a completion request with no api error
-    Then  64 tokens are predicted matching fun|Annaks|popcorns|pictry
+    Then  64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
     And   the completion is  truncated
     And   109 prompt tokens are processed

@@ -65,9 +65,9 @@ Feature: llama.cpp server
     And   the completion is <truncated> truncated

     Examples: Prompts
-      | model        | system_prompt               | user_prompt                          | max_tokens | re_content             | n_prompt | n_predicted | enable_streaming | truncated |
-      | llama-2      | Book                        | What is the best book                | 8          | (Here\|what)+          | 77       | 8           | disabled         | not       |
-      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird)+ | -1       | 64          | enabled          |           |
+      | model        | system_prompt               | user_prompt                          | max_tokens | re_content                        | n_prompt | n_predicted | enable_streaming | truncated |
+      | llama-2      | Book                        | What is the best book                | 8          | (Here\|what)+                     | 77       | 8           | disabled         | not       |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird\|Annabyear)+ | -1       | 64          | enabled          |           |


   Scenario: Tokenize / Detokenize
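
On the regex columns above: the `\|` sequences are just Gherkin table escapes for a literal pipe, so each `re_content` cell is an ordinary alternation; the extra alternatives (`Anna`, `forest`, `bowl`, `Annabyear`) cover tokens the model now emits after the repeat-penalty default change in #6127. A small illustrative check of how such a pattern is applied (the sample completion text is made up):

import re

re_content = "(princesses|everyone|kids|Anna|forest)+"               # cell value after unescaping
content = "Anna told everyone in the forest a joke about kids."      # hypothetical completion

assert re.search(re_content, content), f"/{re_content}/ must match ```{content}```"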

+ 17 - 16
examples/server/tests/features/steps/steps.py

@@ -66,7 +66,7 @@ def step_server_config(context, server_fqdn, server_port):
 def step_download_hf_model(context, hf_file, hf_repo):
     context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
     if context.debug:
-        print(f"model file: {context.model_file}\n")
+        print(f"model file: {context.model_file}")


 @step('a model file {model_file}')
@@ -137,9 +137,12 @@ def step_start_server(context):
     if 'GITHUB_ACTIONS' in os.environ:
         max_attempts *= 2

+    addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM)
+    family, typ, proto, _, sockaddr = addrs[0]
+
     while True:
-        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
-            result = sock.connect_ex((context.server_fqdn, context.server_port))
+        with closing(socket.socket(family, typ, proto)) as sock:
+            result = sock.connect_ex(sockaddr)
             if result == 0:
                 print("\x1b[33;46mserver started!\x1b[0m")
                 return
@@ -209,7 +212,7 @@ async def step_request_completion(context, api_error):
                                           user_api_key=context.user_api_key)
     context.tasks_result.append(completion)
     if context.debug:
-        print(f"Completion response: {completion}\n")
+        print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"

@@ -354,7 +357,7 @@ def step_prompt_passkey(context, passkey, i_pos):
         prompt += context.prompt_junk_suffix
     if context.debug:
         passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
-        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
+        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```")
     context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
     context.n_prompts = len(context.prompts)

@@ -363,7 +366,7 @@ def step_prompt_passkey(context, passkey, i_pos):
 @async_run_until_complete
 async def step_oai_chat_completions(context, api_error):
     if context.debug:
-        print(f"Submitting OAI compatible completions request...\n")
+        print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
@@ -508,12 +511,12 @@ async def step_all_embeddings_are_the_same(context):
             embedding1 = np.array(embeddings[i])
             embedding2 = np.array(embeddings[j])
             if context.debug:
-                print(f"embedding1: {embedding1[-8:]}\n")
-                print(f"embedding2: {embedding2[-8:]}\n")
+                print(f"embedding1: {embedding1[-8:]}")
+                print(f"embedding2: {embedding2[-8:]}")
             similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
             msg = f"Similarity between {i} and {j}: {similarity:.10f}"
             if context.debug:
-                print(f"{msg}\n")
+                print(f"{msg}")
             assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg


@@ -630,7 +633,7 @@ async def step_prometheus_metrics_exported(context):
             metrics_raw = await metrics_response.text()
             metric_exported = False
             if context.debug:
-                print(f"/metrics answer:\n{metrics_raw}\n")
+                print(f"/metrics answer:\n{metrics_raw}")
             context.metrics = {}
             for metric in parser.text_string_to_metric_families(metrics_raw):
                 match metric.name:
@@ -932,7 +935,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
             last_match = end
         highlighted += content[last_match:]
         if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
-          print(f"Checking completion response: {highlighted}\n")
+          print(f"Checking completion response: {highlighted}")
         assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
     if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
@@ -942,7 +945,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
     if context.debug:
-        print(f"Waiting for all {n_tasks} tasks results...\n")
+        print(f"Waiting for all {n_tasks} tasks results...")
     for task_no in range(n_tasks):
         context.tasks_result.append(await context.concurrent_tasks.pop())
     n_completions = len(context.tasks_result)
@@ -959,7 +962,7 @@ async def wait_for_health_status(context,
                                  slots_processing=None,
                                  expected_slots=None):
     if context.debug:
-        print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
+        print(f"Starting checking for health for expected_health_status={expected_health_status}")
     interval = 0.5
     counter = 0
     if 'GITHUB_ACTIONS' in os.environ:
@@ -1048,8 +1051,6 @@ def start_server_background(context):
     if 'LLAMA_SERVER_BIN_PATH' in os.environ:
         context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
     server_listen_addr = context.server_fqdn
-    if os.name == 'nt':
-        server_listen_addr = '0.0.0.0'
     server_args = [
         '--host', server_listen_addr,
         '--port', context.server_port,
@@ -1088,7 +1089,7 @@ def start_server_background(context):
         server_args.append('--verbose')
     if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
         server_args.extend(['--log-format', "text"])
-    print(f"starting server with: {context.server_path} {server_args}\n")
+    print(f"starting server with: {context.server_path} {server_args}")
     flags = 0
     if 'nt' == os.name:
         flags |= subprocess.DETACHED_PROCESS
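
The `step_start_server` hunk above is what makes the readiness probe work on dual-stack machines: hard-coding `AF_INET` breaks when the host name resolves to `::1` first, whereas `getaddrinfo` returns the right family for whatever address the resolver prefers. A self-contained sketch of the same idea, with a hypothetical helper, host, and port:

import socket
from contextlib import closing

def can_connect(host="localhost", port=8080):
    # Resolve first so the socket family (AF_INET or AF_INET6) matches
    # whatever the resolver returns, instead of assuming IPv4.
    addrs = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
    family, typ, proto, _, sockaddr = addrs[0]
    with closing(socket.socket(family, typ, proto)) as sock:
        # connect_ex() returns an error code (0 on success) instead of raising.
        return sock.connect_ex(sockaddr) == 0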

+ 0 - 1
examples/server/tests/requirements.txt

@@ -3,5 +3,4 @@ behave~=1.2.6
 huggingface_hub~=0.20.3
 numpy~=1.24.4
 openai~=0.25.0
-psutil~=5.9.8
 prometheus-client~=0.20.0