
server: continue to update other slots on embedding concurrent request (#5699)

* server: #5655 - continue to update other slots on embedding concurrent request.

* server: tests: add multi users embeddings as fixed

* server: tests: adding OAI compatible embedding concurrent endpoint

* server: tests: adding OAI compatible embedding with multiple inputs
Pierrick Hymbert 1 year ago
parent
commit
9e359a4f47

+ 1 - 1
examples/server/server.cpp

@@ -1836,7 +1836,7 @@ struct llama_server_context
                     send_embedding(slot);
                     slot.release();
                     slot.i_batch = -1;
-                    return true;
+                    continue;
                 }

                 completion_token_output result;
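The one-line change above is the whole server-side fix: in the loop that walks the slots after a batch has been decoded, an embedding slot used to `return true` as soon as its result was sent, ending the pass and leaving the other slots' concurrent requests unserved. With `continue`, the embedding is delivered and the loop moves on to the remaining slots. A minimal Python model of that control flow (hypothetical names, not the actual C++ in server.cpp):

    # Toy model of the slot-update pass (hypothetical names, not the real C++).
    # Replacing `return` with `continue` is what lets every slot in the batch
    # be served, instead of only the first embedding slot encountered.
    def update_slots(slots):
        served = []
        for slot in slots:
            if slot["task"] == "embedding":
                served.append((slot["id"], "embedding"))  # send_embedding + release
                continue  # the fix: keep updating the other slots in this pass
            served.append((slot["id"], "token"))          # completion slots keep decoding
        return served

    # Two concurrent embedding requests, one per slot.
    assert update_slots([{"id": 0, "task": "embedding"},
                         {"id": 1, "task": "embedding"}]) == [(0, "embedding"),
                                                              (1, "embedding")]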

+ 1 - 33
examples/server/tests/features/issues.feature

@@ -1,36 +1,4 @@
 # List of ongoing issues
 @bug
 Feature: Issues
-    # Issue #5655
-  Scenario: Multi users embeddings
-    Given a server listening on localhost:8080
-    And   a model file stories260K.gguf
-    And   a model alias tinyllama-2
-    And   42 as server seed
-    And   64 KV cache size
-    And   2 slots
-    And   continuous batching
-    And   embeddings extraction
-    Then  the server is starting
-    Then  the server is healthy
-
-    Given a prompt:
-      """
-      Write a very long story about AI.
-      """
-    And a prompt:
-      """
-      Write another very long music lyrics.
-      """
-    And a prompt:
-      """
-      Write a very long poem.
-      """
-    And a prompt:
-      """
-      Write a very long joke.
-      """
-    Given concurrent embedding requests
-    Then the server is busy
-    Then the server is idle
-    Then all embeddings are generated
+  # No confirmed issue at the moment

+ 46 - 0
examples/server/tests/features/parallel.feature

@@ -8,6 +8,7 @@ Feature: Parallel
     And   42 as server seed
     And   64 KV cache size
     And   2 slots
+    And   embeddings extraction
     And   continuous batching
     Then  the server is starting
     Then  the server is healthy
@@ -75,3 +76,48 @@ Feature: Parallel
     Then the server is busy
     Then the server is idle
     Then all prompts are predicted
+
+  Scenario: Multi users embeddings
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write a very long poem.
+      """
+    And a prompt:
+      """
+      Write a very long joke.
+      """
+    Given concurrent embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated
+
+  Scenario: Multi users OAI compatibility embeddings
+    Given a prompt:
+      """
+      In which country Paris is located ?
+      """
+    And a prompt:
+      """
+      Is Madrid the capital of Spain ?
+      """
+    And a prompt:
+      """
+      What is the biggest US city ?
+      """
+    And a prompt:
+      """
+      What is the capital of Bulgaria ?
+      """
+    And   a model tinyllama-2
+    Given concurrent OAI embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated
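
The two new scenarios exercise the same server behaviour through different routes: the first reuses the existing `concurrent embedding requests` step against the native `/embedding` endpoint, the second goes through the OpenAI-compatible `/v1/embeddings` endpoint. Outside the behave harness, the concurrent OAI case amounts to something like the sketch below (assumptions: a llama.cpp server with embeddings enabled listening on localhost:8080 and a `tinyllama-2` model alias, as in these feature files):

    import asyncio
    import aiohttp

    # Rough client-side equivalent of "concurrent OAI embedding requests":
    # several prompts posted to the OpenAI-compatible endpoint at the same time.
    # Assumes a llama.cpp server with embeddings enabled on localhost:8080.
    async def fetch_embedding(session, prompt):
        async with session.post("http://localhost:8080/v1/embeddings",
                                json={"input": prompt, "model": "tinyllama-2"}) as resp:
            assert resp.status == 200
            body = await resp.json()
            return body["data"][0]["embedding"]

    async def main():
        prompts = ["Write a very long story about AI.",
                   "Write a very long poem."]
        async with aiohttp.ClientSession() as session:
            embeddings = await asyncio.gather(
                *(fetch_embedding(session, p) for p in prompts))
        print([len(e) for e in embeddings])  # one vector per concurrent request

    asyncio.run(main())

Before the server.cpp change, the second of two such requests could be left waiting, since the server stopped updating the other slot after sending the first embedding.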

+ 13 - 0
examples/server/tests/features/server.feature

@@ -60,6 +60,19 @@ Feature: llama.cpp server
      """
      Then embeddings are generated

+  Scenario: OAI Embeddings compatibility with multiple inputs
+    Given a model tinyllama-2
+    Given a prompt:
+      """
+      In which country Paris is located ?
+      """
+    And a prompt:
+      """
+      Is Madrid the capital of Spain ?
+      """
+    When an OAI compatible embeddings computation request for multiple inputs
+    Then embeddings are generated
+

   Scenario: Tokenize / Detokenize
     When tokenizing:
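
The new multiple-inputs scenario checks that `input` may be a list and that the response carries one entry per prompt in `data`. A minimal sketch of that single batched call, under the same assumptions as the sketch above (server with embeddings enabled on localhost:8080, model alias tinyllama-2):

    import asyncio
    import aiohttp

    # One OpenAI-compatible request whose "input" is a list of prompts;
    # the response is expected to contain one embedding per input.
    async def main():
        prompts = ["In which country Paris is located ?",
                   "Is Madrid the capital of Spain ?"]
        async with aiohttp.ClientSession() as session:
            async with session.post("http://localhost:8080/v1/embeddings",
                                    json={"input": prompts, "model": "tinyllama-2"}) as resp:
                body = await resp.json()
                assert body["object"] == "list"
                assert len(body["data"]) == len(prompts)

    asyncio.run(main())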

+ 107 - 44
examples/server/tests/features/steps/steps.py

@@ -1,4 +1,5 @@
 import asyncio
+import collections
 import json
 import os
 import re
@@ -261,35 +262,35 @@ def step_a_prompt_prompt(context, prompt):
 @step(u'concurrent completion requests')
 @async_run_until_complete()
 async def step_concurrent_completion_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_completion,
-                                         # prompt is inserted automatically
-                                         context.base_url,
-                                         debug=context.debug,
-                                         n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
-                                         server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key if hasattr(context,
-                                                                                      'user_api_key') else None)
+    await concurrent_requests(context,
+                              request_completion,
+                              # prompt is inserted automatically
+                              context.base_url,
+                              debug=context.debug,
+                              n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
+                              server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key if hasattr(context,
+                                                                           'user_api_key') else None)


 @step(u'concurrent OAI completions requests')
 @async_run_until_complete
 async def step_oai_chat_completions(context):
-    await concurrent_completion_requests(context, oai_chat_completions,
-                                         # user_prompt is inserted automatically
-                                         context.system_prompt,
-                                         context.base_url,
-                                         True,  # async_client
-                                         model=context.model
-                                         if hasattr(context, 'model') else None,
-                                         n_predict=context.n_predict
-                                         if hasattr(context, 'n_predict') else None,
-                                         enable_streaming=context.enable_streaming
-                                         if hasattr(context, 'enable_streaming') else None,
-                                         server_seed=context.server_seed
-                                         if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key
-                                         if hasattr(context, 'user_api_key') else None)
+    await concurrent_requests(context, oai_chat_completions,
+                              # user_prompt is inserted automatically
+                              context.system_prompt,
+                              context.base_url,
+                              True,  # async_client
+                              model=context.model
+                              if hasattr(context, 'model') else None,
+                              n_predict=context.n_predict
+                              if hasattr(context, 'n_predict') else None,
+                              enable_streaming=context.enable_streaming
+                              if hasattr(context, 'enable_streaming') else None,
+                              server_seed=context.server_seed
+                              if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key
+                              if hasattr(context, 'user_api_key') else None)


 @step(u'all prompts are predicted')
@@ -316,36 +317,58 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
-    content = context.text
-    base_url = context.base_url
-    context.embeddings = await request_embedding(content, base_url)
+    context.embeddings = await request_embedding(context.text, base_url=context.base_url)


 @step(u'embeddings are generated')
 def step_assert_embeddings(context):
-    assert_embeddings(context.embeddings)
+    if len(context.prompts) == 0:
+        assert_embeddings(context.embeddings)
+    else:
+        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
+                                                                 f"context.prompts={context.prompts}\n"
+                                                                 f"context.embeddings={context.embeddings}")
+        for embedding in context.embeddings:
+            context.prompts.pop()
+            assert_embeddings(embedding)


 @step(u'an OAI compatible embeddings computation request for')
-def step_oai_compute_embedding(context):
-    openai.api_key = 'nope'  # openai client always expects an api_keu
-    if context.user_api_key is not None:
-        openai.api_key = context.user_api_key
-    openai.api_base = f'{context.base_url}/v1'
-    embeddings = openai.Embedding.create(
-        model=context.model,
-        input=context.text,
-    )
-    context.embeddings = embeddings
+@async_run_until_complete
+async def step_oai_compute_embeddings(context):
+    context.embeddings = await request_oai_embeddings(context.text,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)
+
+
+@step(u'an OAI compatible embeddings computation request for multiple inputs')
+@async_run_until_complete
+async def step_oai_compute_embeddings_multiple_inputs(context):
+    context.embeddings = await request_oai_embeddings(context.prompts,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)


 @step(u'concurrent embedding requests')
 @async_run_until_complete()
 async def step_concurrent_embedding_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_embedding,
-                                         # prompt is inserted automatically
-                                         context.base_url)
+    await concurrent_requests(context,
+                              request_embedding,
+                              # prompt is inserted automatically
+                              base_url=context.base_url)
+
+
+@step(u'concurrent OAI embedding requests')
+@async_run_until_complete()
+async def step_concurrent_oai_embedding_requests(context):
+    await concurrent_requests(context,
+                              request_oai_embeddings,
+                              # prompt is inserted automatically
+                              base_url=context.base_url,
+                              async_client=True,
+                              model=context.model)


 @step(u'all embeddings are generated')
@@ -401,7 +424,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
     assert context.options_response.headers[cors_header] == cors_header_value


-async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
+async def concurrent_requests(context, f_completion, *args, **kwargs):
     n_prompts = len(context.prompts)
     if context.debug:
         print(f"starting {n_prompts} concurrent completion requests...")
@@ -565,7 +588,7 @@ async def oai_chat_completions(user_prompt,
     return completion_response


-async def request_embedding(content, base_url):
+async def request_embedding(content, base_url=None):
     async with aiohttp.ClientSession() as session:
         async with session.post(f'{base_url}/embedding',
                                 json={
@@ -576,6 +599,46 @@ async def request_embedding(content, base_url):
             return response_json['embedding']


+async def request_oai_embeddings(input,
+                                 base_url=None, user_api_key=None,
+                                 model=None, async_client=False):
+    # openai client always expects an api_key
+    user_api_key = user_api_key if user_api_key is not None else 'nope'
+    if async_client:
+        origin = 'llama.cpp'
+        if user_api_key is not None:
+            headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f'{base_url}/v1/embeddings',
+                                    json={
+                                        "input": input,
+                                        "model": model,
+                                    },
+                                    headers=headers) as response:
+                assert response.status == 200, f"received status code not expected: {response.status}"
+                assert response.headers['Access-Control-Allow-Origin'] == origin
+                assert response.headers['Content-Type'] == "application/json; charset=utf-8"
+                response_json = await response.json()
+                assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
+                assert response_json['object'] == 'list'
+                return response_json['data']
+    else:
+        openai.api_key = user_api_key
+        openai.api_base = f'{base_url}/v1'
+        oai_embeddings = openai.Embedding.create(
+            model=model,
+            input=input,
+        )
+
+        if isinstance(input, collections.abc.Sequence):
+            embeddings = []
+            for an_oai_embeddings in oai_embeddings.data:
+                embeddings.append(an_oai_embeddings.embedding)
+        else:
+            embeddings = oai_embeddings.data.embedding
+        return embeddings
+
+
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']