|
|
@@ -45,6 +45,35 @@ def test_embedding_multiple():
|
|
|
assert len(d['embedding']) > 1
|
|
|
|
|
|
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ "content,is_multi_prompt",
|
|
|
+ [
|
|
|
+ # single prompt
|
|
|
+ ("string", False),
|
|
|
+ ([12, 34, 56], False),
|
|
|
+ ([12, 34, "string", 56, 78], False),
|
|
|
+ # multiple prompts
|
|
|
+ (["string1", "string2"], True),
|
|
|
+ (["string1", [12, 34, 56]], True),
|
|
|
+ ([[12, 34, 56], [12, 34, 56]], True),
|
|
|
+ ([[12, 34, 56], [12, "string", 34, 56]], True),
|
|
|
+ ]
|
|
|
+)
|
|
|
+def test_embedding_mixed_input(content, is_multi_prompt: bool):
|
|
|
+ global server
|
|
|
+ server.start()
|
|
|
+ res = server.make_request("POST", "/embeddings", data={"content": content})
|
|
|
+ assert res.status_code == 200
|
|
|
+ if is_multi_prompt:
|
|
|
+ assert len(res.body) == len(content)
|
|
|
+ for d in res.body:
|
|
|
+ assert 'embedding' in d
|
|
|
+ assert len(d['embedding']) > 1
|
|
|
+ else:
|
|
|
+ assert 'embedding' in res.body
|
|
|
+ assert len(res.body['embedding']) > 1
|
|
|
+
|
|
|
+
|
|
|
def test_embedding_openai_library_single():
|
|
|
global server
|
|
|
server.start()
|
|
|
@@ -102,8 +131,8 @@ def test_same_prompt_give_same_result():
|
|
|
@pytest.mark.parametrize(
|
|
|
"content,n_tokens",
|
|
|
[
|
|
|
- ("I believe the meaning of life is", 7),
|
|
|
- ("This is a test", 4),
|
|
|
+ ("I believe the meaning of life is", 9),
|
|
|
+ ("This is a test", 6),
|
|
|
]
|
|
|
)
|
|
|
def test_embedding_usage_single(content, n_tokens):
|
|
|
@@ -126,4 +155,4 @@ def test_embedding_usage_multiple():
|
|
|
})
|
|
|
assert res.status_code == 200
|
|
|
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
|
|
- assert res.body['usage']['prompt_tokens'] == 2 * 7
|
|
|
+ assert res.body['usage']['prompt_tokens'] == 2 * 9
|