test_vision_api.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import pytest
  2. from utils import *
  3. import base64
  4. import requests
  5. server: ServerProcess
  6. IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
  7. IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
  8. response = requests.get(IMG_URL_0)
  9. response.raise_for_status() # Raise an exception for bad status codes
  10. IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
  11. @pytest.fixture(autouse=True)
  12. def create_server():
  13. global server
  14. server = ServerPreset.tinygemma3()
  15. @pytest.mark.parametrize(
  16. "prompt, image_url, success, re_content",
  17. [
  18. # test model is trained on CIFAR-10, but it's quite dumb due to small size
  19. ("What is this:\n", IMG_URL_0, True, "(cat)+"),
  20. ("What is this:\n", "IMG_BASE64_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
  21. ("What is this:\n", IMG_URL_1, True, "(frog)+"),
  22. ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache
  23. ("What is this:\n", "malformed", False, None),
  24. ("What is this:\n", "https://google.com/404", False, None), # non-existent image
  25. ("What is this:\n", "https://ggml.ai", False, None), # non-image data
  26. ]
  27. )
  28. def test_vision_chat_completion(prompt, image_url, success, re_content):
  29. global server
  30. server.start(timeout_seconds=60) # vision model may take longer to load due to download size
  31. if image_url == "IMG_BASE64_0":
  32. image_url = IMG_BASE64_0
  33. res = server.make_request("POST", "/chat/completions", data={
  34. "temperature": 0.0,
  35. "top_k": 1,
  36. "messages": [
  37. {"role": "user", "content": [
  38. {"type": "text", "text": prompt},
  39. {"type": "image_url", "image_url": {
  40. "url": image_url,
  41. }},
  42. ]},
  43. ],
  44. })
  45. if success:
  46. assert res.status_code == 200
  47. choice = res.body["choices"][0]
  48. assert "assistant" == choice["message"]["role"]
  49. assert match_regex(re_content, choice["message"]["content"])
  50. else:
  51. assert res.status_code != 200