test_embedding.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import pytest
  2. from openai import OpenAI
  3. from utils import *
  4. server = ServerPreset.bert_bge_small()
  5. EPSILON = 1e-3
  6. @pytest.fixture(scope="module", autouse=True)
  7. def create_server():
  8. global server
  9. server = ServerPreset.bert_bge_small()
  10. def test_embedding_single():
  11. global server
  12. server.pooling = 'last'
  13. server.start()
  14. res = server.make_request("POST", "/v1/embeddings", data={
  15. "input": "I believe the meaning of life is",
  16. })
  17. assert res.status_code == 200
  18. assert len(res.body['data']) == 1
  19. assert 'embedding' in res.body['data'][0]
  20. assert len(res.body['data'][0]['embedding']) > 1
  21. # make sure embedding vector is normalized
  22. assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
  23. def test_embedding_multiple():
  24. global server
  25. server.pooling = 'last'
  26. server.start()
  27. res = server.make_request("POST", "/v1/embeddings", data={
  28. "input": [
  29. "I believe the meaning of life is",
  30. "Write a joke about AI from a very long prompt which will not be truncated",
  31. "This is a test",
  32. "This is another test",
  33. ],
  34. })
  35. assert res.status_code == 200
  36. assert len(res.body['data']) == 4
  37. for d in res.body['data']:
  38. assert 'embedding' in d
  39. assert len(d['embedding']) > 1
  40. @pytest.mark.parametrize(
  41. "input,is_multi_prompt",
  42. [
  43. # do not crash on empty input
  44. ("", False),
  45. # single prompt
  46. ("string", False),
  47. ([12, 34, 56], False),
  48. ([12, 34, "string", 56, 78], False),
  49. # multiple prompts
  50. (["string1", "string2"], True),
  51. (["string1", [12, 34, 56]], True),
  52. ([[12, 34, 56], [12, 34, 56]], True),
  53. ([[12, 34, 56], [12, "string", 34, 56]], True),
  54. ]
  55. )
  56. def test_embedding_mixed_input(input, is_multi_prompt: bool):
  57. global server
  58. server.start()
  59. res = server.make_request("POST", "/v1/embeddings", data={"input": input})
  60. assert res.status_code == 200
  61. data = res.body['data']
  62. if is_multi_prompt:
  63. assert len(data) == len(input)
  64. for d in data:
  65. assert 'embedding' in d
  66. assert len(d['embedding']) > 1
  67. else:
  68. assert 'embedding' in data[0]
  69. assert len(data[0]['embedding']) > 1
  70. def test_embedding_pooling_none():
  71. global server
  72. server.pooling = 'none'
  73. server.start()
  74. res = server.make_request("POST", "/embeddings", data={
  75. "input": "hello hello hello",
  76. })
  77. assert res.status_code == 200
  78. assert 'embedding' in res.body[0]
  79. assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
  80. # make sure embedding vector is not normalized
  81. for x in res.body[0]['embedding']:
  82. assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
  83. def test_embedding_pooling_none_oai():
  84. global server
  85. server.pooling = 'none'
  86. server.start()
  87. res = server.make_request("POST", "/v1/embeddings", data={
  88. "input": "hello hello hello",
  89. })
  90. # /v1/embeddings does not support pooling type 'none'
  91. assert res.status_code == 400
  92. assert "error" in res.body
  93. def test_embedding_openai_library_single():
  94. global server
  95. server.pooling = 'last'
  96. server.start()
  97. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
  98. res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
  99. assert len(res.data) == 1
  100. assert len(res.data[0].embedding) > 1
  101. def test_embedding_openai_library_multiple():
  102. global server
  103. server.pooling = 'last'
  104. server.start()
  105. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
  106. res = client.embeddings.create(model="text-embedding-3-small", input=[
  107. "I believe the meaning of life is",
  108. "Write a joke about AI from a very long prompt which will not be truncated",
  109. "This is a test",
  110. "This is another test",
  111. ])
  112. assert len(res.data) == 4
  113. for d in res.data:
  114. assert len(d.embedding) > 1
  115. def test_embedding_error_prompt_too_long():
  116. global server
  117. server.pooling = 'last'
  118. server.start()
  119. res = server.make_request("POST", "/v1/embeddings", data={
  120. "input": "This is a test " * 512,
  121. })
  122. assert res.status_code != 200
  123. assert "too large" in res.body["error"]["message"]
  124. def test_same_prompt_give_same_result():
  125. server.pooling = 'last'
  126. server.start()
  127. res = server.make_request("POST", "/v1/embeddings", data={
  128. "input": [
  129. "I believe the meaning of life is",
  130. "I believe the meaning of life is",
  131. "I believe the meaning of life is",
  132. "I believe the meaning of life is",
  133. "I believe the meaning of life is",
  134. ],
  135. })
  136. assert res.status_code == 200
  137. assert len(res.body['data']) == 5
  138. for i in range(1, len(res.body['data'])):
  139. v0 = res.body['data'][0]['embedding']
  140. vi = res.body['data'][i]['embedding']
  141. for x, y in zip(v0, vi):
  142. assert abs(x - y) < EPSILON
  143. @pytest.mark.parametrize(
  144. "content,n_tokens",
  145. [
  146. ("I believe the meaning of life is", 9),
  147. ("This is a test", 6),
  148. ]
  149. )
  150. def test_embedding_usage_single(content, n_tokens):
  151. global server
  152. server.start()
  153. res = server.make_request("POST", "/v1/embeddings", data={"input": content})
  154. assert res.status_code == 200
  155. assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
  156. assert res.body['usage']['prompt_tokens'] == n_tokens
  157. def test_embedding_usage_multiple():
  158. global server
  159. server.start()
  160. res = server.make_request("POST", "/v1/embeddings", data={
  161. "input": [
  162. "I believe the meaning of life is",
  163. "I believe the meaning of life is",
  164. ],
  165. })
  166. assert res.status_code == 200
  167. assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
  168. assert res.body['usage']['prompt_tokens'] == 2 * 9