#!/usr/bin/env python
import pytest

# ensure grandparent path is in sys.path
from pathlib import Path
import sys

path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))

from utils import *
from enum import Enum

server: ServerProcess

TIMEOUT_START_SLOW = 15 * 60  # this is needed for real model tests
TIMEOUT_HTTP_REQUEST = 60
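
# Auto-used fixture: gives every test a fresh tinyllama2 preset server to
# configure; tests tweak its fields and then call server.start() themselves.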
@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
    server.model_alias = "tinyllama-2-tool-call"
    server.server_port = 8081
    server.n_slots = 1
    server.n_ctx = 8192
    server.n_batch = 2048
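
# Whether a test exercises the streaming or the non-streaming completions path.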
class CompletionMode(Enum):
    NORMAL = "normal"
    STREAMED = "streamed"
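
# Minimal tool whose schema only accepts {"success": true}; useful for checking
# that grammar-constrained tool calls honor a "const" value.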
TEST_TOOL = {
    "type": "function",
    "function": {
        "name": "test",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "success": {"type": "boolean", "const": True},
            },
            "required": ["success"]
        }
    }
}
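
# Code-interpreter-style tool: the model is expected to produce a single
# "code" string argument.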
PYTHON_TOOL = {
    "type": "function",
    "function": {
        "name": "python",
        "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
        "parameters": {
            "type": "object",
            "properties": {
                "code": {
                    "type": "string",
                    "description": "The code to run in the ipython interpreter."
                }
            },
            "required": ["code"]
        }
    }
}
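
# Weather tool with a single free-form "location" string argument.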
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'"
                }
            },
            "required": ["location"]
        }
    }
}
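
# Shared body for the `tool_choice: "required"` tests: the response must contain
# exactly one call to `tool`, whose name matches and whose JSON-encoded
# arguments contain `argument_key` (when given).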
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
    body = server.make_any_request("POST", "/v1/chat/completions", data={
        "max_tokens": n_predict,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "tool_choice": "required",
        "tools": [tool],
        "parallel_tool_calls": False,
        **kwargs,
    })
    # assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
    choice = body["choices"][0]
    tool_calls = choice["message"].get("tool_calls")
    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
    tool_call = tool_calls[0]
    assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
    # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
    expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
    assert expected_function_name == tool_call["function"]["name"]
    actual_arguments = tool_call["function"]["arguments"]
    assert isinstance(actual_arguments, str)
    if argument_key is not None:
        actual_arguments = json.loads(actual_arguments)
        assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
  99. @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
  100. @pytest.mark.parametrize("template_name,tool,argument_key", [
  101. ("google-gemma-2-2b-it", TEST_TOOL, "success"),
  102. ("google-gemma-2-2b-it", TEST_TOOL, "success"),
  103. ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
  104. ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
  105. ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
  106. ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
    global server
    n_predict = 1024
    # server = ServerPreset.stories15m_moe()
    server.jinja = True
    server.n_predict = n_predict
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start()
    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)

@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,tool,argument_key", [
    ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
    ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
    ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
    ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
    ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
    # Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
    # ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
    ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
    ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
    ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
    ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
    ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
  138. # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True),
  139. # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
  140. ])
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
    global server
    n_predict = 512
    # server = ServerPreset.stories15m_moe()
    server.jinja = True
    server.n_predict = n_predict
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_START_SLOW)
    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)

@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
    (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
    # (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
    (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
    (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
    global server
    n_predict = 512
    server.jinja = True
    server.n_ctx = 8192
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    if isinstance(template_override, tuple):
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    elif isinstance(template_override, str):
        server.chat_template = template_override
    server.start(timeout_seconds=TIMEOUT_START_SLOW)
    body = server.make_any_request("POST", "/v1/chat/completions", data={
        "max_tokens": n_predict,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "tool_choice": "required",
        "tools": [tool],
        "parallel_tool_calls": False,
        "stream": stream == CompletionMode.STREAMED,
        "temperature": 0.0,
        "top_k": 1,
        "top_p": 1.0,
    }, timeout=TIMEOUT_HTTP_REQUEST)
    choice = body["choices"][0]
    tool_calls = choice["message"].get("tool_calls")
    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
    tool_call = tool_calls[0]
    # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
    expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
    assert expected_function_name == tool_call["function"]["name"]
    actual_arguments = tool_call["function"]["arguments"]
    assert isinstance(actual_arguments, str)
    if argument_key is not None:
        actual_arguments = json.loads(actual_arguments)
        assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
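
# Negative path: when no tools are offered, or tool_choice is "none", the reply
# must be plain content with no tool_calls.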
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
    body = server.make_any_request("POST", "/v1/chat/completions", data={
        "max_tokens": n_predict,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "say hello world with python"},
        ],
        "tools": tools if tools else None,
        "tool_choice": tool_choice,
        **kwargs,
    }, timeout=TIMEOUT_HTTP_REQUEST)
    choice = body["choices"][0]
    assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
  246. @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
  247. @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
  248. ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
  249. ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
  250. ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
  251. ])
  252. def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
  253. global server
  254. server.n_predict = n_predict
  255. server.jinja = True
  256. server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
  257. server.start()
  258. do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)

@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
    ("meetkai-functionary-medium-v3.2", 256, [], None),
    ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
    ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'),
    ("meetkai-functionary-medium-v3.1", 256, [], None),
    ("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None),
    ("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'),
    ("meta-llama-Llama-3.2-3B-Instruct", 256, [], None),
    ("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
    ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
    global server
    server.n_predict = n_predict
    server.jinja = True
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_START_SLOW)
    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)

@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [
    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
    # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
    # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
    # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
])
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
    global server
    n_predict = 512
    server.jinja = True
    server.n_ctx = 8192
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    if isinstance(template_override, tuple):
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    elif isinstance(template_override, str):
        server.chat_template = template_override
    server.start()
    do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
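
# The weather request must yield exactly one get_current_weather call whose
# "location" argument is some recognizable spelling of Istanbul.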
def do_test_weather(server: ServerProcess, **kwargs):
    body = server.make_any_request("POST", "/v1/chat/completions", data={
        "messages": [
            {"role": "system", "content": "You are a chatbot that uses tools/functions. Don't overthink things."},
            {"role": "user", "content": "What is the weather in Istanbul?"},
        ],
        "tools": [WEATHER_TOOL],
        **kwargs,
    }, timeout=TIMEOUT_HTTP_REQUEST)
    choice = body["choices"][0]
    tool_calls = choice["message"].get("tool_calls")
    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
    tool_call = tool_calls[0]
    # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
    assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
    # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
    actual_arguments = json.loads(tool_call["function"]["arguments"])
    assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
    location = actual_arguments["location"]
    assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
    assert re.match('^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'

@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
    (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
    (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
    (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
    (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
    (None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    (None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
    (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
    (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
    # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
    # (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
    global server
    server.jinja = True
    server.n_ctx = 8192 * 2
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    if isinstance(template_override, tuple):
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    elif isinstance(template_override, str):
        server.chat_template = template_override
    server.start(timeout_seconds=TIMEOUT_START_SLOW)
    do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
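
# Replays a finished `calculate` tool call (with a pre-baked numeric result) and
# checks that the model's final answer quotes roughly that value (~0.56) as
# plain content, without issuing another tool call.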
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
    body = server.make_any_request("POST", "/v1/chat/completions", data={
        "max_tokens": n_predict,
        "messages": [
            {"role": "system", "content": "You are a tool-calling assistant. You express numerical values with at most two decimals."},
            {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_6789",
                        "type": "function",
                        "function": {
                            "name": "calculate",
                            "arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
                        }
                    }
                ]
            },
            {
                "role": "tool",
                "name": "calculate",
                "content": "0.55644242476",
                "tool_call_id": "call_6789"
            }
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "calculate",
                    "description": "A calculator function that computes values of arithmetic expressions in the Python syntax",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "expression": {
                                "type": "string",
                                "description": "An arithmetic expression to compute the value of (Python syntax, assuming all floats)"
                            }
                        },
                        "required": ["expression"]
                    }
                }
            }
        ],
        **kwargs,
    }, timeout=TIMEOUT_HTTP_REQUEST)
    choice = body["choices"][0]
    tool_calls = choice["message"].get("tool_calls")
    assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
    content = choice["message"].get("content")
    assert content is not None, f'Expected content in {choice["message"]}'
    if result_override is not None:
        assert re.match(result_override, content), f'Expected {result_override}, got {content}'
    else:
        assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \
            f'Expected something like "The y coordinate is 0.56.", got {content}'
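
# Reasoning coverage: depending on server.reasoning_format, a model's chain of
# thought is expected either split out into "reasoning_content" or absent.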
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", [
    (128, 'deepseek', None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (128, None, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (1024, 'deepseek', "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    (1024, 'deepseek', "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
    # (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    # (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
])
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
    global server
    server.reasoning_format = reasoning_format
    server.jinja = True
    server.n_ctx = 8192 * 2
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    if isinstance(template_override, tuple):
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    elif isinstance(template_override, str):
        server.chat_template = template_override
    server.start()
    body = server.make_any_request("POST", "/v1/chat/completions", data={
        "max_tokens": n_predict,
        "messages": [
            {"role": "user", "content": "What's the sum of 102 and 7?"},
        ],
        "stream": stream == CompletionMode.STREAMED,
    }, timeout=TIMEOUT_HTTP_REQUEST)
    choice = body["choices"][0]
    assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
    content = choice["message"].get("content")
    if expect_content is None:
        assert content in (None, ""), f'Expected no content in {choice["message"]}'
    else:
        assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
    reasoning_content = choice["message"].get("reasoning_content")
    if expect_reasoning_content is None:
        assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}'
    else:
        assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}'

@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [
    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
    # ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
])
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
    global server
    n_predict = 512  # High because of DeepSeek R1
    server.jinja = True
    server.n_ctx = 8192
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    if isinstance(template_override, tuple):
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    elif isinstance(template_override, str):
        server.chat_template = template_override
    server.start(timeout_seconds=TIMEOUT_START_SLOW)
    do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
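
# The reply must contain exactly one python tool call whose "code" argument is,
# once comments are stripped, a hello-world print statement.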
def do_test_hello_world(server: ServerProcess, **kwargs):
    body = server.make_any_request("POST", "/v1/chat/completions", data={
        "messages": [
            {"role": "system", "content": "You are a tool-calling agent."},
            {"role": "user", "content": "say hello world with python"},
        ],
        "tools": [PYTHON_TOOL],
        **kwargs,
    }, timeout=TIMEOUT_HTTP_REQUEST)
    choice = body["choices"][0]
    tool_calls = choice["message"].get("tool_calls")
    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
    tool_call = tool_calls[0]
    # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
    assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
    # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
    actual_arguments = json.loads(tool_call["function"]["arguments"])
    assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
    code = actual_arguments["code"]
    assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
    assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'