test_vision_api.py

import pytest
from utils import *
import base64
import requests

server: ServerProcess

IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"

response = requests.get(IMG_URL_0)
response.raise_for_status() # Raise an exception for bad status codes
IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")

response = requests.get(IMG_URL_1)
response.raise_for_status() # Raise an exception for bad status codes
IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
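
# Note: both encodings of each test image are kept on purpose. The OpenAI-compatible
# /chat/completions tests below pass the full data URI (IMG_BASE64_URI_*), while the
# native /completions and /embeddings tests pass the raw base64 payload (IMG_BASE64_*).
# This is an observation about how the tests in this file use them, not a statement
# about every input format the server accepts.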

JSON_MULTIMODAL_KEY = "multimodal_data"
JSON_PROMPT_STRING_KEY = "prompt_string"
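
# For reference, the tests below combine these keys into a prompt object of the form
#   {"prompt_string": "What is this: <__media__>\n", "multimodal_data": ["<base64 image>"]}
# where each <__media__> marker in the prompt string stands for one multimodal entry.
# This shape is inferred from the requests made in this file, not from separate
# server documentation.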

@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinygemma3()

def test_models_supports_multimodal_capability():
    global server
    server.start() # vision model may take longer to load due to download size
    res = server.make_request("GET", "/models", data={})
    assert res.status_code == 200
    model_info = res.body["models"][0]
    print(model_info)
    assert "completion" in model_info["capabilities"]
    assert "multimodal" in model_info["capabilities"]

def test_v1_models_supports_multimodal_capability():
    global server
    server.start() # vision model may take longer to load due to download size
    res = server.make_request("GET", "/v1/models", data={})
    assert res.status_code == 200
    model_info = res.body["models"][0]
    print(model_info)
    assert "completion" in model_info["capabilities"]
    assert "multimodal" in model_info["capabilities"]

@pytest.mark.parametrize(
    "prompt, image_url, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this:\n", IMG_URL_0,                True,  "(cat)+"),
        ("What is this:\n", "IMG_BASE64_URI_0",       True,  "(cat)+"), # placeholder name, resolved inside the test so the huge base64 string doesn't clog up the log
        ("What is this:\n", IMG_URL_1,                True,  "(frog)+"),
        ("Test test\n",     IMG_URL_1,                True,  "(frog)+"), # test invalidate cache
        ("What is this:\n", "malformed",              False, None),
        ("What is this:\n", "https://google.com/404", False, None), # non-existent image
        ("What is this:\n", "https://ggml.ai",        False, None), # non-image data
        # TODO @ngxson : test with multiple images, no images and with audio
        # (a hedged sketch of the multiple-image case appears at the end of this file)
    ]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
    global server
    server.start(timeout_seconds=60) # vision model may take longer to load due to download size
    if image_url == "IMG_BASE64_URI_0":
        # resolve the placeholder name to the actual data URI
        image_url = IMG_BASE64_URI_0
    res = server.make_request("POST", "/chat/completions", data={
        "temperature": 0.0,
        "top_k": 1,
        "messages": [
            {"role": "user", "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {
                    "url": image_url,
                }},
            ]},
        ],
    })
    if success:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert "assistant" == choice["message"]["role"]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code != 200

@pytest.mark.parametrize(
    "prompt, image_data, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", IMG_BASE64_0, True,  "(cat)+"),
        ("What is this: <__media__>\n", IMG_BASE64_1, True,  "(frog)+"),
        ("What is this: <__media__>\n", "malformed",  False, None), # non-image data
        ("What is this:\n",             "",           False, None), # empty string
    ]
)
def test_vision_completion(prompt, image_data, success, re_content):
    global server
    server.start() # vision model may take longer to load due to download size
    res = server.make_request("POST", "/completions", data={
        "temperature": 0.0,
        "top_k": 1,
        "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
    })
    if success:
        assert res.status_code == 200
        content = res.body["content"]
        assert match_regex(re_content, content)
    else:
        assert res.status_code != 200

@pytest.mark.parametrize(
    "prompt, image_data, success",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", IMG_BASE64_0, True), # exceptional, so that we don't clog up the log
        ("What is this: <__media__>\n", IMG_BASE64_1, True),
        ("What is this: <__media__>\n", "malformed",  False), # non-image data
        ("What is this:\n",             "base64",     False), # non-image data
    ]
)
def test_vision_embeddings(prompt, image_data, success):
    global server
    server.server_embeddings = True
    server.n_batch = 512
    server.start() # vision model may take longer to load due to download size
    res = server.make_request("POST", "/embeddings", data={
        "content": [
            { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
            { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
            { JSON_PROMPT_STRING_KEY: prompt, },
        ],
    })
    if success:
        assert res.status_code == 200
        content = res.body
        # Ensure embeddings are stable when multimodal.
        assert content[0]['embedding'] == content[1]['embedding']
        # Ensure embeddings without multimodal but same prompt do not match multimodal embeddings.
        assert content[0]['embedding'] != content[2]['embedding']
    else:
        assert res.status_code != 200
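
# A minimal sketch of one of the TODO cases above (multiple images in a single message),
# referenced from the first parametrize block. It assumes the server accepts several
# image_url parts in one user message the same way it accepts one; since the tiny
# CIFAR-10 model's labels are unreliable with two images, only the status code is
# checked. This is an illustrative sketch, not a confirmed part of the upstream suite.
def test_vision_chat_completion_multiple_images_sketch():
    global server
    server.start(timeout_seconds=60) # vision model may take longer to load due to download size
    res = server.make_request("POST", "/chat/completions", data={
        "temperature": 0.0,
        "top_k": 1,
        "messages": [
            {"role": "user", "content": [
                {"type": "text", "text": "What are these:\n"},
                {"type": "image_url", "image_url": {"url": IMG_URL_0}},
                {"type": "image_url", "image_url": {"url": IMG_URL_1}},
            ]},
        ],
    })
    # Only assert that the request is accepted; content checks are intentionally omitted.
    assert res.status_code == 200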