# test_embedding.py
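"""Server tests for the /embeddings endpoint.

These tests run against the bert_bge_small server preset provided by the shared
test utils, exercising both the raw HTTP API and the official OpenAI Python
client pointed at the local server.
"""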

import pytest
from openai import OpenAI
from utils import *

server = ServerPreset.bert_bge_small()

EPSILON = 1e-3


@pytest.fixture(scope="module", autouse=True)
def create_server():
    global server
    server = ServerPreset.bert_bge_small()


def test_embedding_single():
    global server
    server.start()
    res = server.make_request("POST", "/embeddings", data={
        "input": "I believe the meaning of life is",
    })
    assert res.status_code == 200
    assert len(res.body['data']) == 1
    assert 'embedding' in res.body['data'][0]
    assert len(res.body['data'][0]['embedding']) > 1

    # make sure embedding vector is normalized
    assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
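
# NOTE (assumption about the response shape): requests using the OpenAI-compatible
# "input" field are expected to return a body like
#   {"data": [{"embedding": [...]}, ...], "usage": {...}}
# which is why these tests index into res.body['data'].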


def test_embedding_multiple():
    global server
    server.start()
    res = server.make_request("POST", "/embeddings", data={
        "input": [
            "I believe the meaning of life is",
            "Write a joke about AI from a very long prompt which will not be truncated",
            "This is a test",
            "This is another test",
        ],
    })
    assert res.status_code == 200
    assert len(res.body['data']) == 4
    for d in res.body['data']:
        assert 'embedding' in d
        assert len(d['embedding']) > 1


@pytest.mark.parametrize(
    "content,is_multi_prompt",
    [
        # single prompt
        ("string", False),
        ([12, 34, 56], False),
        ([12, 34, "string", 56, 78], False),
        # multiple prompts
        (["string1", "string2"], True),
        (["string1", [12, 34, 56]], True),
        ([[12, 34, 56], [12, 34, 56]], True),
        ([[12, 34, 56], [12, "string", 34, 56]], True),
    ]
)
def test_embedding_mixed_input(content, is_multi_prompt: bool):
    global server
    server.start()
    res = server.make_request("POST", "/embeddings", data={"content": content})
    assert res.status_code == 200
    if is_multi_prompt:
        assert len(res.body) == len(content)
        for d in res.body:
            assert 'embedding' in d
            assert len(d['embedding']) > 1
    else:
        assert 'embedding' in res.body
        assert len(res.body['embedding']) > 1
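
# NOTE (assumption): posting "content" instead of "input" is treated here as the
# server's non-OpenAI-compatible form, so the body is the embedding object itself
# (or a plain list of them) rather than being wrapped in a "data" array.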


def test_embedding_openai_library_single():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
    res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
    assert len(res.data) == 1
    assert len(res.data[0].embedding) > 1
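
# NOTE (assumption): the model name passed to the OpenAI client is a placeholder;
# the local server under test serves whichever model the preset loaded, so
# "text-embedding-3-small" is never actually used.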


def test_embedding_openai_library_multiple():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
    res = client.embeddings.create(model="text-embedding-3-small", input=[
        "I believe the meaning of life is",
        "Write a joke about AI from a very long prompt which will not be truncated",
        "This is a test",
        "This is another test",
    ])
    assert len(res.data) == 4
    for d in res.data:
        assert len(d.embedding) > 1


def test_embedding_error_prompt_too_long():
    global server
    server.start()
    res = server.make_request("POST", "/embeddings", data={
        "input": "This is a test " * 512,
    })
    assert res.status_code != 200
    assert "too large" in res.body["error"]["message"]
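
# NOTE (assumption): 512 repetitions of "This is a test " are expected to exceed
# the embedding model's context window, so the server should reject the request
# with a non-200 status and a "too large" error message.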


def test_same_prompt_give_same_result():
    global server
    server.start()
    res = server.make_request("POST", "/embeddings", data={
        "input": [
            "I believe the meaning of life is",
            "I believe the meaning of life is",
            "I believe the meaning of life is",
            "I believe the meaning of life is",
            "I believe the meaning of life is",
        ],
    })
    assert res.status_code == 200
    assert len(res.body['data']) == 5
    for i in range(1, len(res.body['data'])):
        v0 = res.body['data'][0]['embedding']
        vi = res.body['data'][i]['embedding']
        for x, y in zip(v0, vi):
            assert abs(x - y) < EPSILON
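
# NOTE (assumption): identical prompts in one batch are expected to produce
# embeddings that match element-wise within EPSILON, i.e. batching is assumed
# not to introduce non-determinism beyond small numerical noise.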


@pytest.mark.parametrize(
    "content,n_tokens",
    [
        ("I believe the meaning of life is", 9),
        ("This is a test", 6),
    ]
)
def test_embedding_usage_single(content, n_tokens):
    global server
    server.start()
    res = server.make_request("POST", "/embeddings", data={"input": content})
    assert res.status_code == 200
    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
    assert res.body['usage']['prompt_tokens'] == n_tokens
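
# NOTE (assumption): the expected prompt_tokens counts (9 and 6) include any
# special tokens the preset's tokenizer adds, not just the whitespace-separated
# words in the prompt.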


def test_embedding_usage_multiple():
    global server
    server.start()
    res = server.make_request("POST", "/embeddings", data={
        "input": [
            "I believe the meaning of life is",
            "I believe the meaning of life is",
        ],
    })
    assert res.status_code == 200
    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
    assert res.body['usage']['prompt_tokens'] == 2 * 9