|
|
@@ -5,18 +5,31 @@ import requests
|
|
|
|
|
|
server: ServerProcess
|
|
|
|
|
|
-IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
|
|
|
-IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
|
|
|
-
|
|
|
-response = requests.get(IMG_URL_0)
|
|
|
-response.raise_for_status() # Raise an exception for bad status codes
|
|
|
-IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
|
|
-IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
|
|
|
-
|
|
|
-response = requests.get(IMG_URL_1)
|
|
|
-response.raise_for_status() # Raise an exception for bad status codes
|
|
|
-IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
|
|
-IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
|
|
|
+def get_img_url(id: str) -> str:
|
|
|
+ IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
|
|
|
+ IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
|
|
|
+ if id == "IMG_URL_0":
|
|
|
+ return IMG_URL_0
|
|
|
+ elif id == "IMG_URL_1":
|
|
|
+ return IMG_URL_1
|
|
|
+ elif id == "IMG_BASE64_URI_0":
|
|
|
+ response = requests.get(IMG_URL_0)
|
|
|
+ response.raise_for_status() # Raise an exception for bad status codes
|
|
|
+ return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
|
|
+ elif id == "IMG_BASE64_0":
|
|
|
+ response = requests.get(IMG_URL_0)
|
|
|
+ response.raise_for_status() # Raise an exception for bad status codes
|
|
|
+ return base64.b64encode(response.content).decode("utf-8")
|
|
|
+ elif id == "IMG_BASE64_URI_1":
|
|
|
+ response = requests.get(IMG_URL_1)
|
|
|
+ response.raise_for_status() # Raise an exception for bad status codes
|
|
|
+ return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
|
|
+ elif id == "IMG_BASE64_1":
|
|
|
+ response = requests.get(IMG_URL_1)
|
|
|
+ response.raise_for_status() # Raise an exception for bad status codes
|
|
|
+ return base64.b64encode(response.content).decode("utf-8")
|
|
|
+ else:
|
|
|
+ return id
|
|
|
|
|
|
JSON_MULTIMODAL_KEY = "multimodal_data"
|
|
|
JSON_PROMPT_STRING_KEY = "prompt_string"
|
|
|
@@ -28,7 +41,7 @@ def create_server():
|
|
|
|
|
|
def test_models_supports_multimodal_capability():
|
|
|
global server
|
|
|
- server.start() # vision model may take longer to load due to download size
|
|
|
+ server.start()
|
|
|
res = server.make_request("GET", "/models", data={})
|
|
|
assert res.status_code == 200
|
|
|
model_info = res.body["models"][0]
|
|
|
@@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability():
|
|
|
|
|
|
def test_v1_models_supports_multimodal_capability():
|
|
|
global server
|
|
|
- server.start() # vision model may take longer to load due to download size
|
|
|
+ server.start()
|
|
|
res = server.make_request("GET", "/v1/models", data={})
|
|
|
assert res.status_code == 200
|
|
|
model_info = res.body["models"][0]
|
|
|
@@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability():
|
|
|
"prompt, image_url, success, re_content",
|
|
|
[
|
|
|
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
|
|
- ("What is this:\n", IMG_URL_0, True, "(cat)+"),
|
|
|
- ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
|
|
|
- ("What is this:\n", IMG_URL_1, True, "(frog)+"),
|
|
|
- ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache
|
|
|
+ ("What is this:\n", "IMG_URL_0", True, "(cat)+"),
|
|
|
+ ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
|
|
|
+ ("What is this:\n", "IMG_URL_1", True, "(frog)+"),
|
|
|
+ ("Test test\n", "IMG_URL_1", True, "(frog)+"), # test invalidate cache
|
|
|
("What is this:\n", "malformed", False, None),
|
|
|
("What is this:\n", "https://google.com/404", False, None), # non-existent image
|
|
|
("What is this:\n", "https://ggml.ai", False, None), # non-image data
|
|
|
@@ -62,9 +75,7 @@ def test_v1_models_supports_multimodal_capability():
|
|
|
)
|
|
|
def test_vision_chat_completion(prompt, image_url, success, re_content):
|
|
|
global server
|
|
|
- server.start(timeout_seconds=60) # vision model may take longer to load due to download size
|
|
|
- if image_url == "IMG_BASE64_URI_0":
|
|
|
- image_url = IMG_BASE64_URI_0
|
|
|
+ server.start()
|
|
|
res = server.make_request("POST", "/chat/completions", data={
|
|
|
"temperature": 0.0,
|
|
|
"top_k": 1,
|
|
|
@@ -72,7 +83,7 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
|
|
|
{"role": "user", "content": [
|
|
|
{"type": "text", "text": prompt},
|
|
|
{"type": "image_url", "image_url": {
|
|
|
- "url": image_url,
|
|
|
+ "url": get_img_url(image_url),
|
|
|
}},
|
|
|
]},
|
|
|
],
|
|
|
@@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
|
|
|
"prompt, image_data, success, re_content",
|
|
|
[
|
|
|
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
|
|
- ("What is this: <__media__>\n", IMG_BASE64_0, True, "(cat)+"),
|
|
|
- ("What is this: <__media__>\n", IMG_BASE64_1, True, "(frog)+"),
|
|
|
+ ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
|
|
|
+ ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
|
|
|
("What is this: <__media__>\n", "malformed", False, None), # non-image data
|
|
|
("What is this:\n", "", False, None), # empty string
|
|
|
]
|
|
|
)
|
|
|
def test_vision_completion(prompt, image_data, success, re_content):
|
|
|
global server
|
|
|
- server.start() # vision model may take longer to load due to download size
|
|
|
+ server.start()
|
|
|
res = server.make_request("POST", "/completions", data={
|
|
|
"temperature": 0.0,
|
|
|
"top_k": 1,
|
|
|
- "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
|
|
|
+ "prompt": {
|
|
|
+ JSON_PROMPT_STRING_KEY: prompt,
|
|
|
+ JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
|
|
|
+ },
|
|
|
})
|
|
|
if success:
|
|
|
assert res.status_code == 200
|
|
|
@@ -116,17 +130,18 @@ def test_vision_completion(prompt, image_data, success, re_content):
|
|
|
"prompt, image_data, success",
|
|
|
[
|
|
|
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
|
|
- ("What is this: <__media__>\n", IMG_BASE64_0, True), # exceptional, so that we don't cog up the log
|
|
|
- ("What is this: <__media__>\n", IMG_BASE64_1, True),
|
|
|
+ ("What is this: <__media__>\n", "IMG_BASE64_0", True),
|
|
|
+ ("What is this: <__media__>\n", "IMG_BASE64_1", True),
|
|
|
("What is this: <__media__>\n", "malformed", False), # non-image data
|
|
|
("What is this:\n", "base64", False), # non-image data
|
|
|
]
|
|
|
)
|
|
|
def test_vision_embeddings(prompt, image_data, success):
|
|
|
global server
|
|
|
- server.server_embeddings=True
|
|
|
- server.n_batch=512
|
|
|
- server.start() # vision model may take longer to load due to download size
|
|
|
+ server.server_embeddings = True
|
|
|
+ server.n_batch = 512
|
|
|
+ server.start()
|
|
|
+ image_data = get_img_url(image_data)
|
|
|
res = server.make_request("POST", "/embeddings", data={
|
|
|
"content": [
|
|
|
{ JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
|