| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- import pytest
- from utils import *
- import base64
- import requests
# Server under test, shared by every test in this module. The bare annotation
# only declares the type; no value is bound until the autouse `create_server`
# fixture assigns a fresh ServerPreset.tinygemma3() instance before each test.
server: ServerProcess
# Remote PNG fixtures hosted next to the tinygemma3 test model.
IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"

# Seconds to wait for an image download. requests.get has no timeout by
# default, so without this a stalled connection would block the test run forever.
IMG_DOWNLOAD_TIMEOUT = 30

# Downloaded image bytes keyed by URL, so each fixture image is fetched at most
# once per test process instead of once per parametrized case.
_img_bytes_cache = {}

# Placeholder ids that expand to a plain remote URL.
_URL_IDS = {
    "IMG_URL_0": IMG_URL_0,
    "IMG_URL_1": IMG_URL_1,
}

# Placeholder ids that expand to base64 image data:
# id -> (source URL, whether to wrap the payload in a data: URI).
_BASE64_IDS = {
    "IMG_BASE64_URI_0": (IMG_URL_0, True),
    "IMG_BASE64_0": (IMG_URL_0, False),
    "IMG_BASE64_URI_1": (IMG_URL_1, True),
    "IMG_BASE64_1": (IMG_URL_1, False),
}


def _download_img(url: str) -> bytes:
    """Return the raw bytes served at *url*, downloading them only on first use."""
    data = _img_bytes_cache.get(url)
    if data is None:
        response = requests.get(url, timeout=IMG_DOWNLOAD_TIMEOUT)
        # Fail loudly on 4xx/5xx rather than base64-encoding an error page.
        response.raise_for_status()
        data = response.content
        _img_bytes_cache[url] = data
    return data


def get_img_url(id: str) -> str:
    """Resolve a test-image placeholder id to the value sent to the server.

    Args:
        id: A placeholder id (see below) or any other string. The name shadows
            the builtin but is kept so keyword callers keep working.

    Returns:
        - ``"IMG_URL_0"`` / ``"IMG_URL_1"``: the remote image URL.
        - ``"IMG_BASE64_0"`` / ``"IMG_BASE64_1"``: the downloaded image,
          base64-encoded.
        - ``"IMG_BASE64_URI_0"`` / ``"IMG_BASE64_URI_1"``: the same payload
          prefixed with ``data:image/png;base64,``.
        - anything else: ``id`` unchanged, so tests can feed malformed input
          (bad URLs, non-image strings) straight to the server.

    Raises:
        requests.HTTPError: if an image download returns an error status.
        requests.Timeout: if a download exceeds ``IMG_DOWNLOAD_TIMEOUT`` seconds.
    """
    if id in _URL_IDS:
        return _URL_IDS[id]
    if id in _BASE64_IDS:
        url, as_data_uri = _BASE64_IDS[id]
        encoded = base64.b64encode(_download_img(url)).decode("utf-8")
        return ("data:image/png;base64," + encoded) if as_data_uri else encoded
    return id
# Keys of the structured prompt object these tests send to /completions and
# /embeddings: the text goes under "prompt_string" and the list of media
# payloads (as produced by get_img_url) under "multimodal_data".
JSON_MULTIMODAL_KEY, JSON_PROMPT_STRING_KEY = "multimodal_data", "prompt_string"
@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a new tinygemma3 preset before each test.

    The preset is not started here: each test tweaks whatever settings it needs
    and then calls ``server.start()`` itself.
    """
    global server
    fresh_preset = ServerPreset.tinygemma3()
    server = fresh_preset
def _assert_model_is_multimodal(endpoint: str) -> None:
    """Start the server, GET *endpoint* and check the first listed model supports vision.

    Shared by the ``/models`` and ``/v1/models`` tests, which read the same
    ``models[0]["capabilities"]`` field from the response body.
    """
    server.start()
    res = server.make_request("GET", endpoint, data={})
    assert res.status_code == 200
    models = res.body["models"]
    # Fail with a readable message instead of an IndexError when nothing is listed.
    assert models, f"{endpoint} returned an empty model list"
    model_info = models[0]
    print(model_info)  # debugging aid; pytest captures stdout by default
    assert "completion" in model_info["capabilities"]
    assert "multimodal" in model_info["capabilities"]


def test_models_supports_multimodal_capability():
    """The native ``/models`` endpoint advertises completion and multimodal capabilities."""
    _assert_model_is_multimodal("/models")


def test_v1_models_supports_multimodal_capability():
    """The OpenAI-style ``/v1/models`` endpoint advertises the same capabilities."""
    _assert_model_is_multimodal("/v1/models")
@pytest.mark.parametrize(
    "prompt, image_url, success, re_content",
    [
        # The test model was trained on CIFAR-10 but is tiny and often mislabels
        # images, so the expected regexes record its greedy (temperature 0,
        # top_k 1) output rather than the true class of each picture.
        ("What is this:\n", "IMG_URL_0", True, "(cat)+"),
        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
        ("What is this:\n", "IMG_URL_1", True, "(frog)+"),
        ("Test test\n", "IMG_URL_1", True, "(frog)+"),  # new prompt, same image: cache invalidation
        ("What is this:\n", "malformed", False, None),
        ("What is this:\n", "https://google.com/404", False, None),  # image URL that does not exist
        ("What is this:\n", "https://ggml.ai", False, None),  # URL that serves non-image data
        # TODO @ngxson : test with multiple images, no images and with audio
    ],
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
    """Send one text part plus one image part to ``/chat/completions``.

    ``image_url`` is a placeholder resolved through ``get_img_url``. Successful
    cases must return an assistant message matching ``re_content``; failing
    cases only require a non-200 status.
    """
    server.start()
    user_parts = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": get_img_url(image_url)}},
    ]
    payload = {
        "temperature": 0.0,
        "top_k": 1,
        "messages": [{"role": "user", "content": user_parts}],
    }
    res = server.make_request("POST", "/chat/completions", data=payload)
    if not success:
        assert res.status_code != 200
        return
    assert res.status_code == 200
    message = res.body["choices"][0]["message"]
    assert message["role"] == "assistant"
    assert match_regex(re_content, message["content"])
@pytest.mark.parametrize(
    "prompt, image_data, success, re_content",
    [
        # Tiny CIFAR-10 test model: expected regexes capture its actual greedy
        # answers, which are not necessarily the correct labels.
        ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
        ("What is this: <__media__>\n", "malformed", False, None),  # payload is not an image
        ("What is this:\n", "", False, None),  # empty media payload
    ],
)
def test_vision_completion(prompt, image_data, success, re_content):
    """POST a structured prompt (text + media list) to ``/completions``.

    ``image_data`` is resolved through ``get_img_url``; ``<__media__>`` in the
    prompt is presumably the server's placeholder for where the media is
    inserted (NOTE(review): confirm against the server docs). Successful cases
    must produce content matching ``re_content``; failing ones a non-200 status.
    """
    server.start()
    structured_prompt = {
        JSON_PROMPT_STRING_KEY: prompt,
        JSON_MULTIMODAL_KEY: [get_img_url(image_data)],
    }
    res = server.make_request(
        "POST",
        "/completions",
        data={"temperature": 0.0, "top_k": 1, "prompt": structured_prompt},
    )
    if not success:
        assert res.status_code != 200
        return
    assert res.status_code == 200
    assert match_regex(re_content, res.body["content"])
@pytest.mark.parametrize(
    "prompt, image_data, success",
    [
        # Tiny CIFAR-10 test model; only embedding stability is checked here.
        ("What is this: <__media__>\n", "IMG_BASE64_0", True),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True),
        ("What is this: <__media__>\n", "malformed", False),  # payload is not an image
        ("What is this:\n", "base64", False),  # payload is not an image
    ],
)
def test_vision_embeddings(prompt, image_data, success):
    """Embed the same multimodal prompt twice plus its text-only variant.

    On success the two multimodal embeddings must be identical (deterministic),
    and they must differ from the text-only embedding (the media is actually
    used). Failing cases only require a non-200 status.
    """
    server.server_embeddings = True
    server.n_batch = 512
    server.start()
    media_value = get_img_url(image_data)
    contents = [
        {JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [media_value]}
        for _ in range(2)
    ]
    # Same text without any media, used as the "must differ" reference.
    contents.append({JSON_PROMPT_STRING_KEY: prompt})
    res = server.make_request("POST", "/embeddings", data={"content": contents})
    if not success:
        assert res.status_code != 200
        return
    assert res.status_code == 200
    results = res.body
    assert results[0]['embedding'] == results[1]['embedding']
    assert results[0]['embedding'] != results[2]['embedding']
|