1
0

test_embedding.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. import base64
  2. import struct
  3. import pytest
  4. from openai import OpenAI
  5. from utils import *
  6. server = ServerPreset.bert_bge_small()
  7. EPSILON = 1e-3
  8. @pytest.fixture(autouse=True)
  9. def create_server():
  10. global server
  11. server = ServerPreset.bert_bge_small()
  12. def test_embedding_single():
  13. global server
  14. server.pooling = 'last'
  15. server.start()
  16. res = server.make_request("POST", "/v1/embeddings", data={
  17. "input": "I believe the meaning of life is",
  18. })
  19. assert res.status_code == 200
  20. assert len(res.body['data']) == 1
  21. assert 'embedding' in res.body['data'][0]
  22. assert len(res.body['data'][0]['embedding']) > 1
  23. # make sure embedding vector is normalized
  24. assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
  25. def test_embedding_multiple():
  26. global server
  27. server.pooling = 'last'
  28. server.start()
  29. res = server.make_request("POST", "/v1/embeddings", data={
  30. "input": [
  31. "I believe the meaning of life is",
  32. "Write a joke about AI from a very long prompt which will not be truncated",
  33. "This is a test",
  34. "This is another test",
  35. ],
  36. })
  37. assert res.status_code == 200
  38. assert len(res.body['data']) == 4
  39. for d in res.body['data']:
  40. assert 'embedding' in d
  41. assert len(d['embedding']) > 1
  42. def test_embedding_multiple_with_fa():
  43. server = ServerPreset.bert_bge_small_with_fa()
  44. server.pooling = 'last'
  45. server.start()
  46. # one of these should trigger the FA branch (i.e. context size % 256 == 0)
  47. res = server.make_request("POST", "/v1/embeddings", data={
  48. "input": [
  49. "a "*253,
  50. "b "*254,
  51. "c "*255,
  52. "d "*256,
  53. ],
  54. })
  55. assert res.status_code == 200
  56. assert len(res.body['data']) == 4
  57. for d in res.body['data']:
  58. assert 'embedding' in d
  59. assert len(d['embedding']) > 1
  60. @pytest.mark.parametrize(
  61. "input,is_multi_prompt",
  62. [
  63. # do not crash on empty input
  64. ("", False),
  65. # single prompt
  66. ("string", False),
  67. ([12, 34, 56], False),
  68. ([12, 34, "string", 56, 78], False),
  69. # multiple prompts
  70. (["string1", "string2"], True),
  71. (["string1", [12, 34, 56]], True),
  72. ([[12, 34, 56], [12, 34, 56]], True),
  73. ([[12, 34, 56], [12, "string", 34, 56]], True),
  74. ]
  75. )
  76. def test_embedding_mixed_input(input, is_multi_prompt: bool):
  77. global server
  78. server.start()
  79. res = server.make_request("POST", "/v1/embeddings", data={"input": input})
  80. assert res.status_code == 200
  81. data = res.body['data']
  82. if is_multi_prompt:
  83. assert len(data) == len(input)
  84. for d in data:
  85. assert 'embedding' in d
  86. assert len(d['embedding']) > 1
  87. else:
  88. assert 'embedding' in data[0]
  89. assert len(data[0]['embedding']) > 1
  90. def test_embedding_pooling_none():
  91. global server
  92. server.pooling = 'none'
  93. server.start()
  94. res = server.make_request("POST", "/embeddings", data={
  95. "input": "hello hello hello",
  96. })
  97. assert res.status_code == 200
  98. assert 'embedding' in res.body[0]
  99. assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
  100. # make sure embedding vector is not normalized
  101. for x in res.body[0]['embedding']:
  102. assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
  103. def test_embedding_pooling_none_oai():
  104. global server
  105. server.pooling = 'none'
  106. server.start()
  107. res = server.make_request("POST", "/v1/embeddings", data={
  108. "input": "hello hello hello",
  109. })
  110. # /v1/embeddings does not support pooling type 'none'
  111. assert res.status_code == 400
  112. assert "error" in res.body
  113. def test_embedding_openai_library_single():
  114. global server
  115. server.pooling = 'last'
  116. server.start()
  117. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
  118. res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
  119. assert len(res.data) == 1
  120. assert len(res.data[0].embedding) > 1
  121. def test_embedding_openai_library_multiple():
  122. global server
  123. server.pooling = 'last'
  124. server.start()
  125. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
  126. res = client.embeddings.create(model="text-embedding-3-small", input=[
  127. "I believe the meaning of life is",
  128. "Write a joke about AI from a very long prompt which will not be truncated",
  129. "This is a test",
  130. "This is another test",
  131. ])
  132. assert len(res.data) == 4
  133. for d in res.data:
  134. assert len(d.embedding) > 1
  135. def test_embedding_error_prompt_too_long():
  136. global server
  137. server.pooling = 'last'
  138. server.start()
  139. res = server.make_request("POST", "/v1/embeddings", data={
  140. "input": "This is a test " * 512,
  141. })
  142. assert res.status_code != 200
  143. assert "too large" in res.body["error"]["message"]
  144. def test_same_prompt_give_same_result():
  145. server.pooling = 'last'
  146. server.start()
  147. res = server.make_request("POST", "/v1/embeddings", data={
  148. "input": [
  149. "I believe the meaning of life is",
  150. "I believe the meaning of life is",
  151. "I believe the meaning of life is",
  152. "I believe the meaning of life is",
  153. "I believe the meaning of life is",
  154. ],
  155. })
  156. assert res.status_code == 200
  157. assert len(res.body['data']) == 5
  158. for i in range(1, len(res.body['data'])):
  159. v0 = res.body['data'][0]['embedding']
  160. vi = res.body['data'][i]['embedding']
  161. for x, y in zip(v0, vi):
  162. assert abs(x - y) < EPSILON
  163. @pytest.mark.parametrize(
  164. "content,n_tokens",
  165. [
  166. ("I believe the meaning of life is", 9),
  167. ("This is a test", 6),
  168. ]
  169. )
  170. def test_embedding_usage_single(content, n_tokens):
  171. global server
  172. server.start()
  173. res = server.make_request("POST", "/v1/embeddings", data={"input": content})
  174. assert res.status_code == 200
  175. assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
  176. assert res.body['usage']['prompt_tokens'] == n_tokens
  177. def test_embedding_usage_multiple():
  178. global server
  179. server.start()
  180. res = server.make_request("POST", "/v1/embeddings", data={
  181. "input": [
  182. "I believe the meaning of life is",
  183. "I believe the meaning of life is",
  184. ],
  185. })
  186. assert res.status_code == 200
  187. assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
  188. assert res.body['usage']['prompt_tokens'] == 2 * 9
  189. def test_embedding_openai_library_base64():
  190. server.start()
  191. test_input = "Test base64 embedding output"
  192. # get embedding in default format
  193. res = server.make_request("POST", "/v1/embeddings", data={
  194. "input": test_input
  195. })
  196. assert res.status_code == 200
  197. vec0 = res.body["data"][0]["embedding"]
  198. # get embedding in base64 format
  199. res = server.make_request("POST", "/v1/embeddings", data={
  200. "input": test_input,
  201. "encoding_format": "base64"
  202. })
  203. assert res.status_code == 200
  204. assert "data" in res.body
  205. assert len(res.body["data"]) == 1
  206. embedding_data = res.body["data"][0]
  207. assert "embedding" in embedding_data
  208. assert isinstance(embedding_data["embedding"], str)
  209. # Verify embedding is valid base64
  210. decoded = base64.b64decode(embedding_data["embedding"])
  211. # Verify decoded data can be converted back to float array
  212. float_count = len(decoded) // 4 # 4 bytes per float
  213. floats = struct.unpack(f'{float_count}f', decoded)
  214. assert len(floats) > 0
  215. assert all(isinstance(x, float) for x in floats)
  216. assert len(floats) == len(vec0)
  217. # make sure the decoded data is the same as the original
  218. for x, y in zip(floats, vec0):
  219. assert abs(x - y) < EPSILON