|
@@ -14,8 +14,9 @@ def create_server():
|
|
|
|
|
|
|
|
def test_embedding_single():
|
|
def test_embedding_single():
|
|
|
global server
|
|
global server
|
|
|
|
|
+ server.pooling = 'last'
|
|
|
server.start()
|
|
server.start()
|
|
|
- res = server.make_request("POST", "/embeddings", data={
|
|
|
|
|
|
|
+ res = server.make_request("POST", "/v1/embeddings", data={
|
|
|
"input": "I believe the meaning of life is",
|
|
"input": "I believe the meaning of life is",
|
|
|
})
|
|
})
|
|
|
assert res.status_code == 200
|
|
assert res.status_code == 200
|
|
@@ -29,8 +30,9 @@ def test_embedding_single():
|
|
|
|
|
|
|
|
def test_embedding_multiple():
|
|
def test_embedding_multiple():
|
|
|
global server
|
|
global server
|
|
|
|
|
+ server.pooling = 'last'
|
|
|
server.start()
|
|
server.start()
|
|
|
- res = server.make_request("POST", "/embeddings", data={
|
|
|
|
|
|
|
+ res = server.make_request("POST", "/v1/embeddings", data={
|
|
|
"input": [
|
|
"input": [
|
|
|
"I believe the meaning of life is",
|
|
"I believe the meaning of life is",
|
|
|
"Write a joke about AI from a very long prompt which will not be truncated",
|
|
"Write a joke about AI from a very long prompt which will not be truncated",
|
|
@@ -46,7 +48,7 @@ def test_embedding_multiple():
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
@pytest.mark.parametrize(
|
|
|
- "content,is_multi_prompt",
|
|
|
|
|
|
|
+ "input,is_multi_prompt",
|
|
|
[
|
|
[
|
|
|
# single prompt
|
|
# single prompt
|
|
|
("string", False),
|
|
("string", False),
|
|
@@ -59,25 +61,55 @@ def test_embedding_multiple():
|
|
|
([[12, 34, 56], [12, "string", 34, 56]], True),
|
|
([[12, 34, 56], [12, "string", 34, 56]], True),
|
|
|
]
|
|
]
|
|
|
)
|
|
)
|
|
|
-def test_embedding_mixed_input(content, is_multi_prompt: bool):
|
|
|
|
|
|
|
+def test_embedding_mixed_input(input, is_multi_prompt: bool):
|
|
|
global server
|
|
global server
|
|
|
server.start()
|
|
server.start()
|
|
|
- res = server.make_request("POST", "/embeddings", data={"content": content})
|
|
|
|
|
|
|
+ res = server.make_request("POST", "/v1/embeddings", data={"input": input})
|
|
|
assert res.status_code == 200
|
|
assert res.status_code == 200
|
|
|
|
|
+ data = res.body['data']
|
|
|
if is_multi_prompt:
|
|
if is_multi_prompt:
|
|
|
- assert len(res.body) == len(content)
|
|
|
|
|
- for d in res.body:
|
|
|
|
|
|
|
+ assert len(data) == len(input)
|
|
|
|
|
+ for d in data:
|
|
|
assert 'embedding' in d
|
|
assert 'embedding' in d
|
|
|
assert len(d['embedding']) > 1
|
|
assert len(d['embedding']) > 1
|
|
|
else:
|
|
else:
|
|
|
- assert 'embedding' in res.body
|
|
|
|
|
- assert len(res.body['embedding']) > 1
|
|
|
|
|
|
|
+ assert 'embedding' in data[0]
|
|
|
|
|
+ assert len(data[0]['embedding']) > 1
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def test_embedding_pooling_none():
|
|
|
|
|
+ global server
|
|
|
|
|
+ server.pooling = 'none'
|
|
|
|
|
+ server.start()
|
|
|
|
|
+ res = server.make_request("POST", "/embeddings", data={
|
|
|
|
|
+ "input": "hello hello hello",
|
|
|
|
|
+ })
|
|
|
|
|
+ assert res.status_code == 200
|
|
|
|
|
+ assert 'embedding' in res.body[0]
|
|
|
|
|
+ assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
|
|
|
|
|
+
|
|
|
|
|
+ # make sure embedding vector is not normalized
|
|
|
|
|
+ for x in res.body[0]['embedding']:
|
|
|
|
|
+ assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def test_embedding_pooling_none_oai():
|
|
|
|
|
+ global server
|
|
|
|
|
+ server.pooling = 'none'
|
|
|
|
|
+ server.start()
|
|
|
|
|
+ res = server.make_request("POST", "/v1/embeddings", data={
|
|
|
|
|
+ "input": "hello hello hello",
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ # /v1/embeddings does not support pooling type 'none'
|
|
|
|
|
+ assert res.status_code == 400
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_embedding_openai_library_single():
|
|
def test_embedding_openai_library_single():
|
|
|
global server
|
|
global server
|
|
|
|
|
+ server.pooling = 'last'
|
|
|
server.start()
|
|
server.start()
|
|
|
- client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
|
|
|
|
|
|
+ client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
|
|
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
|
|
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
|
|
|
assert len(res.data) == 1
|
|
assert len(res.data) == 1
|
|
|
assert len(res.data[0].embedding) > 1
|
|
assert len(res.data[0].embedding) > 1
|
|
@@ -85,8 +117,9 @@ def test_embedding_openai_library_single():
|
|
|
|
|
|
|
|
def test_embedding_openai_library_multiple():
|
|
def test_embedding_openai_library_multiple():
|
|
|
global server
|
|
global server
|
|
|
|
|
+ server.pooling = 'last'
|
|
|
server.start()
|
|
server.start()
|
|
|
- client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
|
|
|
|
|
|
+ client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
|
|
res = client.embeddings.create(model="text-embedding-3-small", input=[
|
|
res = client.embeddings.create(model="text-embedding-3-small", input=[
|
|
|
"I believe the meaning of life is",
|
|
"I believe the meaning of life is",
|
|
|
"Write a joke about AI from a very long prompt which will not be truncated",
|
|
"Write a joke about AI from a very long prompt which will not be truncated",
|
|
@@ -100,8 +133,9 @@ def test_embedding_openai_library_multiple():
|
|
|
|
|
|
|
|
def test_embedding_error_prompt_too_long():
|
|
def test_embedding_error_prompt_too_long():
|
|
|
global server
|
|
global server
|
|
|
|
|
+ server.pooling = 'last'
|
|
|
server.start()
|
|
server.start()
|
|
|
- res = server.make_request("POST", "/embeddings", data={
|
|
|
|
|
|
|
+ res = server.make_request("POST", "/v1/embeddings", data={
|
|
|
"input": "This is a test " * 512,
|
|
"input": "This is a test " * 512,
|
|
|
})
|
|
})
|
|
|
assert res.status_code != 200
|
|
assert res.status_code != 200
|
|
@@ -109,8 +143,9 @@ def test_embedding_error_prompt_too_long():
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_same_prompt_give_same_result():
|
|
def test_same_prompt_give_same_result():
|
|
|
|
|
+ server.pooling = 'last'
|
|
|
server.start()
|
|
server.start()
|
|
|
- res = server.make_request("POST", "/embeddings", data={
|
|
|
|
|
|
|
+ res = server.make_request("POST", "/v1/embeddings", data={
|
|
|
"input": [
|
|
"input": [
|
|
|
"I believe the meaning of life is",
|
|
"I believe the meaning of life is",
|
|
|
"I believe the meaning of life is",
|
|
"I believe the meaning of life is",
|
|
@@ -138,7 +173,7 @@ def test_same_prompt_give_same_result():
|
|
|
def test_embedding_usage_single(content, n_tokens):
|
|
def test_embedding_usage_single(content, n_tokens):
|
|
|
global server
|
|
global server
|
|
|
server.start()
|
|
server.start()
|
|
|
- res = server.make_request("POST", "/embeddings", data={"input": content})
|
|
|
|
|
|
|
+ res = server.make_request("POST", "/v1/embeddings", data={"input": content})
|
|
|
assert res.status_code == 200
|
|
assert res.status_code == 200
|
|
|
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
|
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
|
|
assert res.body['usage']['prompt_tokens'] == n_tokens
|
|
assert res.body['usage']['prompt_tokens'] == n_tokens
|
|
@@ -147,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens):
|
|
|
def test_embedding_usage_multiple():
|
|
def test_embedding_usage_multiple():
|
|
|
global server
|
|
global server
|
|
|
server.start()
|
|
server.start()
|
|
|
- res = server.make_request("POST", "/embeddings", data={
|
|
|
|
|
|
|
+ res = server.make_request("POST", "/v1/embeddings", data={
|
|
|
"input": [
|
|
"input": [
|
|
|
"I believe the meaning of life is",
|
|
"I believe the meaning of life is",
|
|
|
"I believe the meaning of life is",
|
|
"I believe the meaning of life is",
|