@@ -1,4 +1,5 @@
 import asyncio
+import collections
 import json
 import os
 import re
@@ -261,35 +262,35 @@ def step_a_prompt_prompt(context, prompt):
 @step(u'concurrent completion requests')
 @async_run_until_complete()
 async def step_concurrent_completion_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_completion,
-                                         # prompt is inserted automatically
-                                         context.base_url,
-                                         debug=context.debug,
-                                         n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
-                                         server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key if hasattr(context,
-                                                                                      'user_api_key') else None)
+    await concurrent_requests(context,
+                              request_completion,
+                              # prompt is inserted automatically
+                              context.base_url,
+                              debug=context.debug,
+                              n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
+                              server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key if hasattr(context,
+                                                                           'user_api_key') else None)
 
 
 @step(u'concurrent OAI completions requests')
 @async_run_until_complete
 async def step_oai_chat_completions(context):
-    await concurrent_completion_requests(context, oai_chat_completions,
-                                         # user_prompt is inserted automatically
-                                         context.system_prompt,
-                                         context.base_url,
-                                         True,  # async_client
-                                         model=context.model
-                                         if hasattr(context, 'model') else None,
-                                         n_predict=context.n_predict
-                                         if hasattr(context, 'n_predict') else None,
-                                         enable_streaming=context.enable_streaming
-                                         if hasattr(context, 'enable_streaming') else None,
-                                         server_seed=context.server_seed
-                                         if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key
-                                         if hasattr(context, 'user_api_key') else None)
+    await concurrent_requests(context, oai_chat_completions,
+                              # user_prompt is inserted automatically
+                              context.system_prompt,
+                              context.base_url,
+                              True,  # async_client
+                              model=context.model
+                              if hasattr(context, 'model') else None,
+                              n_predict=context.n_predict
+                              if hasattr(context, 'n_predict') else None,
+                              enable_streaming=context.enable_streaming
+                              if hasattr(context, 'enable_streaming') else None,
+                              server_seed=context.server_seed
+                              if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key
+                              if hasattr(context, 'user_api_key') else None)
 
 
 @step(u'all prompts are predicted')
@@ -316,36 +317,58 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
-    content = context.text
-    base_url = context.base_url
-    context.embeddings = await request_embedding(content, base_url)
+    context.embeddings = await request_embedding(context.text, base_url=context.base_url)
 
 
 @step(u'embeddings are generated')
 def step_assert_embeddings(context):
-    assert_embeddings(context.embeddings)
+    if len(context.prompts) == 0:
+        assert_embeddings(context.embeddings)
+    else:
+        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
+                                                                 f"context.prompts={context.prompts}\n"
+                                                                 f"context.embeddings={context.embeddings}")
+        for embedding in context.embeddings:
+            context.prompts.pop()
+            assert_embeddings(embedding)
 
 
 @step(u'an OAI compatible embeddings computation request for')
-def step_oai_compute_embedding(context):
-    openai.api_key = 'nope'  # openai client always expects an api_keu
-    if context.user_api_key is not None:
-        openai.api_key = context.user_api_key
-    openai.api_base = f'{context.base_url}/v1'
-    embeddings = openai.Embedding.create(
-        model=context.model,
-        input=context.text,
-    )
-    context.embeddings = embeddings
+@async_run_until_complete
+async def step_oai_compute_embeddings(context):
+    context.embeddings = await request_oai_embeddings(context.text,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)
+
+
+@step(u'an OAI compatible embeddings computation request for multiple inputs')
+@async_run_until_complete
+async def step_oai_compute_embeddings_multiple_inputs(context):
+    context.embeddings = await request_oai_embeddings(context.prompts,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)
 
 
 @step(u'concurrent embedding requests')
 @async_run_until_complete()
 async def step_concurrent_embedding_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_embedding,
-                                         # prompt is inserted automatically
-                                         context.base_url)
+    await concurrent_requests(context,
+                              request_embedding,
+                              # prompt is inserted automatically
+                              base_url=context.base_url)
+
+
+@step(u'concurrent OAI embedding requests')
+@async_run_until_complete()
+async def step_concurrent_oai_embedding_requests(context):
+    await concurrent_requests(context,
+                              request_oai_embeddings,
+                              # prompt is inserted automatically
+                              base_url=context.base_url,
+                              async_client=True,
+                              model=context.model)
 
 
 @step(u'all embeddings are generated')
@@ -401,7 +424,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
     assert context.options_response.headers[cors_header] == cors_header_value
 
 
-async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
+async def concurrent_requests(context, f_completion, *args, **kwargs):
     n_prompts = len(context.prompts)
     if context.debug:
         print(f"starting {n_prompts} concurrent completion requests...")
@@ -565,7 +588,7 @@ async def oai_chat_completions(user_prompt,
     return completion_response
 
 
-async def request_embedding(content, base_url):
+async def request_embedding(content, base_url=None):
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{base_url}/embedding',
                                json={
@@ -576,6 +599,46 @@ async def request_embedding(content, base_url):
             return response_json['embedding']
 
 
+async def request_oai_embeddings(input,
+                                 base_url=None, user_api_key=None,
+                                 model=None, async_client=False):
+    # openai client always expects an api_key
+    user_api_key = user_api_key if user_api_key is not None else 'nope'
+    if async_client:
+        origin = 'llama.cpp'
+        if user_api_key is not None:
+            headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f'{base_url}/v1/embeddings',
+                                    json={
+                                        "input": input,
+                                        "model": model,
+                                    },
+                                    headers=headers) as response:
+                assert response.status == 200, f"received status code not expected: {response.status}"
+                assert response.headers['Access-Control-Allow-Origin'] == origin
+                assert response.headers['Content-Type'] == "application/json; charset=utf-8"
+                response_json = await response.json()
+                assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
+                assert response_json['object'] == 'list'
+                return response_json['data']
+    else:
+        openai.api_key = user_api_key
+        openai.api_base = f'{base_url}/v1'
+        oai_embeddings = openai.Embedding.create(
+            model=model,
+            input=input,
+        )
+
+        if isinstance(input, collections.abc.Sequence):
+            embeddings = []
+            for an_oai_embeddings in oai_embeddings.data:
+                embeddings.append(an_oai_embeddings.embedding)
+        else:
+            embeddings = oai_embeddings.data.embedding
+        return embeddings
+
+
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
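
Note on the new `collections` import: the `isinstance(input, collections.abc.Sequence)` check in `request_oai_embeddings` is also true for a plain string, since Python's `str` is itself a `Sequence`. A single-prompt string therefore takes the same list-building branch as a list of prompts, which still behaves correctly because `openai.Embedding.create` returns `data` as a list in both cases. A minimal stdlib-only sketch of the check (sample values are illustrative, not taken from the patch):

    import collections.abc

    # Both a bare prompt string and a list of prompts satisfy the check,
    # so both take the branch that iterates the client's `data` list.
    for value in ("a single prompt", ["prompt 1", "prompt 2"]):
        print(type(value).__name__, isinstance(value, collections.abc.Sequence))
    # -> str True
    # -> list True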