test_tool_call.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
#!/usr/bin/env python
import pytest

# ensure grandparent path is in sys.path
from pathlib import Path
import sys
path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))

from utils import *

# Module-level handle to the server under test; (re)assigned by the autouse
# `create_server` fixture before every test.
server: ServerProcess

# Generous limits: real-model tests download GGUF files and load them,
# which can be very slow on CI machines.
TIMEOUT_SERVER_START = 15*60  # seconds allowed for server startup / model load
TIMEOUT_HTTP_REQUEST = 60     # seconds allowed per chat-completion request
@pytest.fixture(autouse=True)
def create_server():
    """Reset the global `server` to a fresh tinyllama2 preset before each test.

    Tests further configure the instance (chat template, HF repo, n_predict,
    ...) and then call `server.start()` themselves.
    """
    global server
    server = ServerPreset.tinyllama2()
    server.model_alias = "tinyllama-2-tool-call"
    server.server_port = 8081
# Minimal tool schema: a single required boolean "success" pinned to `true`
# via "const" — any grammar-conforming call is trivially verifiable.
TEST_TOOL = {
    "type":"function",
    "function": {
        "name": "test",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "success": {"type": "boolean", "const": True},
            },
            "required": ["success"]
        }
    }
}
# Code-interpreter-style tool: one required string argument "code".
PYTHON_TOOL = {
    "type": "function",
    "function": {
        "name": "python",
        "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
        "parameters": {
            "type": "object",
            "properties": {
                "code": {
                    "type": "string",
                    "description": "The code to run in the ipython interpreter."
                }
            },
            "required": ["code"]
        }
    }
}
# Classic weather-lookup tool: one required string argument "location".
WEATHER_TOOL = {
    "type":"function",
    "function":{
        "name":"get_current_weather",
        "description":"Get the current weather in a given location",
        "parameters":{
            "type":"object",
            "properties":{
                "location":{
                    "type":"string",
                    "description":"The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'"
                }
            },
            "required":["location"]
        }
    }
}
  66. def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
  67. res = server.make_request("POST", "/v1/chat/completions", data={
  68. "max_tokens": n_predict,
  69. "messages": [
  70. {"role": "system", "content": "You are a coding assistant."},
  71. {"role": "user", "content": "Write an example"},
  72. ],
  73. "tool_choice": "required",
  74. "tools": [tool],
  75. "parallel_tool_calls": False,
  76. **kwargs,
  77. })
  78. assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
  79. choice = res.body["choices"][0]
  80. tool_calls = choice["message"].get("tool_calls")
  81. assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
  82. tool_call = tool_calls[0]
  83. assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
  84. assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
  85. expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
  86. assert expected_function_name == tool_call["function"]["name"]
  87. actual_arguments = tool_call["function"]["arguments"]
  88. assert isinstance(actual_arguments, str)
  89. if argument_key is not None:
  90. actual_arguments = json.loads(actual_arguments)
  91. assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
@pytest.mark.parametrize("template_name,tool,argument_key", [
    ("google-gemma-2-2b-it", TEST_TOOL, "success"),
    ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
    ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
    # Fast variant: uses the tiny preset model with an overridden chat
    # template; greedy sampling (temperature=0, top_k=1) keeps it deterministic.
    global server
    n_predict = 512
    # server = ServerPreset.stories15m_moe()
    server.jinja = True
    server.n_predict = n_predict
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
@pytest.mark.slow
@pytest.mark.parametrize("template_name,tool,argument_key", [
    ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
    ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
    ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
    ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
    ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
    ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
    ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
    ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
    ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
    ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
    ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
    # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
    # Slow variant: same tiny model, wider sweep of chat templates, default
    # sampling (no greedy override).
    global server
    n_predict = 512
    # server = ServerPreset.stories15m_moe()
    server.jinja = True
    server.n_predict = n_predict
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
@pytest.mark.slow
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
    (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
    (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
    (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
    # NOTE: the 1B model deliberately borrows the 3B template.
    (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
    (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
    # Downloads real GGUF models from Hugging Face and forces a tool call via
    # tool_choice=required; greedy sampling for determinism.
    global server
    n_predict = 512
    server.n_slots = 1
    server.jinja = True
    server.n_ctx = 8192
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    # template_override is either (hf_repo, variant) -> use a checked-in .jinja
    # file, or a plain string -> a builtin template name (e.g. "chatml").
    if isinstance(template_override, tuple):
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    elif isinstance(template_override, str):
        server.chat_template = template_override
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    res = server.make_request("POST", "/v1/chat/completions", data={
        "max_tokens": n_predict,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "tool_choice": "required",
        "tools": [tool],
        "parallel_tool_calls": False,
        "temperature": 0.0,
        "top_k": 1,
        "top_p": 1.0,
    }, timeout=TIMEOUT_HTTP_REQUEST)
    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
    choice = res.body["choices"][0]
    tool_calls = choice["message"].get("tool_calls")
    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
    tool_call = tool_calls[0]
    # Unlike the tiny-model helper, content alongside the call is tolerated here.
    # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
    expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
    assert expected_function_name == tool_call["function"]["name"]
    actual_arguments = tool_call["function"]["arguments"]
    assert isinstance(actual_arguments, str)
    if argument_key is not None:
        actual_arguments = json.loads(actual_arguments)
        assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
  219. def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
  220. res = server.make_request("POST", "/v1/chat/completions", data={
  221. "max_tokens": n_predict,
  222. "messages": [
  223. {"role": "system", "content": "You are a coding assistant."},
  224. {"role": "user", "content": "say hello world with python"},
  225. ],
  226. "tools": tools if tools else None,
  227. "tool_choice": tool_choice,
  228. **kwargs,
  229. }, timeout=TIMEOUT_HTTP_REQUEST)
  230. assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
  231. choice = res.body["choices"][0]
  232. assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
    ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
    ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
    ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
    # No tools, unused tools, or tool_choice='none' must all yield plain text.
    global server
    server.jinja = True
    server.n_predict = n_predict
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
@pytest.mark.slow
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
    ("meetkai-functionary-medium-v3.2", 256, [], None),
    ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
    ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'),
    ("meetkai-functionary-medium-v3.1", 256, [], None),
    ("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None),
    ("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'),
    ("meta-llama-Llama-3.2-3B-Instruct", 256, [], None),
    ("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
    ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
    # Slow variant of the no-tool-call check over more chat templates.
    global server
    server.jinja = True
    server.n_predict = n_predict
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
@pytest.mark.slow
@pytest.mark.parametrize("hf_repo,template_override", [
    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
])
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
    # End-to-end check that real models call get_current_weather with a
    # sensible "location" argument.
    global server
    n_predict = 512
    server.n_slots = 1
    server.jinja = True
    server.n_ctx = 8192
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    # Tuple -> checked-in .jinja template file; string -> builtin template name.
    if isinstance(template_override, tuple):
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    elif isinstance(template_override, str):
        server.chat_template = template_override
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    do_test_weather(server, max_tokens=n_predict)
  309. def do_test_weather(server: ServerProcess, **kwargs):
  310. res = server.make_request("POST", "/v1/chat/completions", data={
  311. "messages": [
  312. {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
  313. {"role": "user", "content": "What is the weather in Istanbul?"},
  314. ],
  315. "tools": [WEATHER_TOOL],
  316. **kwargs,
  317. }, timeout=TIMEOUT_HTTP_REQUEST)
  318. assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
  319. choice = res.body["choices"][0]
  320. tool_calls = choice["message"].get("tool_calls")
  321. assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
  322. tool_call = tool_calls[0]
  323. # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
  324. assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
  325. assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
  326. actual_arguments = json.loads(tool_call["function"]["arguments"])
  327. assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
  328. location = actual_arguments["location"]
  329. assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
  330. assert re.match('^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
@pytest.mark.slow
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
    (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
    (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
    (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
    (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
    (None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    (None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
    (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
    (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
    # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
    # (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
    # Verifies that models correctly relay a previous tool-call result back to
    # the user; result_override supplies a model-specific expected pattern.
    global server
    server.n_slots = 1
    server.jinja = True
    server.n_ctx = 8192 * 2
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    # Tuple -> checked-in .jinja template file; string -> builtin template name.
    if isinstance(template_override, tuple):
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    elif isinstance(template_override, str):
        server.chat_template = template_override
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    do_test_calc_result(server, result_override, n_predict)
  364. def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
  365. res = server.make_request("POST", "/v1/chat/completions", data={
  366. "max_tokens": n_predict,
  367. "messages": [
  368. {"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
  369. {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
  370. {
  371. "role": "assistant",
  372. "content": None,
  373. "tool_calls": [
  374. {
  375. "id": "call_6789",
  376. "type": "function",
  377. "function": {
  378. "name": "calculate",
  379. "arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
  380. }
  381. }
  382. ]
  383. },
  384. {
  385. "role": "tool",
  386. "name": "calculate",
  387. "content": "0.55644242476",
  388. "tool_call_id": "call_6789"
  389. }
  390. ],
  391. "tools": [
  392. {
  393. "type":"function",
  394. "function":{
  395. "name":"calculate",
  396. "description":"A calculator function that computes values of arithmetic expressions in the Python syntax",
  397. "parameters":{
  398. "type":"object",
  399. "properties":{
  400. "expression":{
  401. "type":"string",
  402. "description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)"
  403. }
  404. },
  405. "required":["expression"]
  406. }
  407. }
  408. }
  409. ],
  410. **kwargs,
  411. }, timeout=TIMEOUT_HTTP_REQUEST)
  412. assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
  413. choice = res.body["choices"][0]
  414. tool_calls = choice["message"].get("tool_calls")
  415. assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
  416. content = choice["message"].get("content")
  417. assert content is not None, f'Expected content in {choice["message"]}'
  418. if result_override is not None:
  419. assert re.match(result_override, content), f'Expected {result_override}, got {content}'
  420. else:
  421. assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \
  422. f'Expected something like "The y coordinate is 0.56.", got {content}'
@pytest.mark.slow
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
    (128, 'deepseek', "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (128, None, "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (1024, 'deepseek', "To find the sum of[\\s\\S]*", "I need to calculate the sum of 102 and 7[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    (1024, 'none', "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    (1024, 'deepseek', "To find the sum of[\\s\\S]*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
    # Checks how <think>-style reasoning is split between "content" and
    # "reasoning_content" depending on the server's reasoning_format setting.
    global server
    server.n_slots = 1
    server.reasoning_format = reasoning_format
    server.jinja = True
    server.n_ctx = 8192 * 2
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    # Tuple -> checked-in .jinja template file; string -> builtin template name.
    if isinstance(template_override, tuple):
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    elif isinstance(template_override, str):
        server.chat_template = template_override
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    res = server.make_request("POST", "/v1/chat/completions", data={
        "max_tokens": n_predict,
        "messages": [
            {"role": "user", "content": "What's the sum of 102 and 7?"},
        ]
    }, timeout=TIMEOUT_HTTP_REQUEST)
    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
    choice = res.body["choices"][0]
    assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
    content = choice["message"].get("content")
    if expect_content is None:
        assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
    else:
        assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
    reasoning_content = choice["message"].get("reasoning_content")
    if expect_reasoning_content is None:
        assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}'
    else:
        assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}'
@pytest.mark.slow
@pytest.mark.parametrize("hf_repo,template_override", [
    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
    # ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
    ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
])
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
    # End-to-end check that real models emit a python tool call printing
    # "hello world".
    global server
    n_predict = 512 # High because of DeepSeek R1
    server.n_slots = 1
    server.jinja = True
    server.n_ctx = 8192
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    # Tuple -> checked-in .jinja template file; string -> builtin template name.
    if isinstance(template_override, tuple):
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    elif isinstance(template_override, str):
        server.chat_template = template_override
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    do_test_hello_world(server, max_tokens=n_predict)
  507. def do_test_hello_world(server: ServerProcess, **kwargs):
  508. res = server.make_request("POST", "/v1/chat/completions", data={
  509. "messages": [
  510. {"role": "system", "content": "You are a tool-calling agent."},
  511. {"role": "user", "content": "say hello world with python"},
  512. ],
  513. "tools": [PYTHON_TOOL],
  514. **kwargs,
  515. }, timeout=TIMEOUT_HTTP_REQUEST)
  516. assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
  517. choice = res.body["choices"][0]
  518. tool_calls = choice["message"].get("tool_calls")
  519. assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
  520. tool_call = tool_calls[0]
  521. # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
  522. assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
  523. assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
  524. actual_arguments = json.loads(tool_call["function"]["arguments"])
  525. assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
  526. code = actual_arguments["code"]
  527. assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
  528. assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'