# test_tool_call.py — tests for tool calling via the server's OpenAI-compatible /chat/completions API.
  1. import pytest
  2. from utils import *
# Global handle to the test server process; (re)assigned for every test by the
# autouse `create_server` fixture below.
server: ServerProcess

# Generous startup timeout (seconds): first run may need to fetch/load a model.
TIMEOUT_SERVER_START = 15*60
# Timeout (seconds) for individual HTTP requests against the server.
TIMEOUT_HTTP_REQUEST = 60
@pytest.fixture(autouse=True)
def create_server():
    """Reset the global `server` to a fresh tinyllama2 preset before each test."""
    global server
    server = ServerPreset.tinyllama2()
    server.model_alias = "tinyllama-2-tool-call"
    server.server_port = 8081
# Minimal tool definition used to exercise constrained tool-call output: the
# single required argument `success` is pinned to `True` via `const`, so a
# correct call is exactly {"success": true}.
TEST_TOOL = {
    "type":"function",
    "function": {
        "name": "test",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "success": {"type": "boolean", "const": True},
            },
            "required": ["success"]
        }
    }
}
# Tool definition advertising a python/ipython interpreter; the single
# required argument is the `code` string to run.
PYTHON_TOOL = {
    "type": "function",
    "function": {
        "name": "python",
        "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
        "parameters": {
            "type": "object",
            "properties": {
                "code": {
                    "type": "string",
                    "description": "The code to run in the ipython interpreter."
                }
            },
            "required": ["code"]
        }
    }
}
# Weather-lookup tool with a single required string argument `location`;
# used by test_weather_tool_call to check argument extraction.
WEATHER_TOOL = {
    "type":"function",
    "function":{
        "name":"get_current_weather",
        "description":"Get the current weather in a given location",
        "parameters":{
            "type":"object",
            "properties":{
                "location":{
                    "type":"string",
                    "description":"The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'"
                }
            },
            "required":["location"]
        }
    }
}
  60. def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
  61. n_predict = 512
  62. global server
  63. # server = ServerPreset.stories15m_moe()
  64. server.jinja = True
  65. server.n_predict = n_predict
  66. server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
  67. server.start(timeout_seconds=TIMEOUT_SERVER_START)
  68. res = server.make_request("POST", "/chat/completions", data={
  69. "max_tokens": n_predict,
  70. "messages": [
  71. {"role": "system", "content": "You are a coding assistant."},
  72. {"role": "user", "content": "Write an example"},
  73. ],
  74. "tool_choice": "required",
  75. "tools": [tool],
  76. "parallel_tool_calls": False,
  77. "temperature": 0.0,
  78. "top_k": 1,
  79. "top_p": 1.0,
  80. })
  81. assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
  82. choice = res.body["choices"][0]
  83. tool_calls = choice["message"].get("tool_calls")
  84. assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
  85. tool_call = tool_calls[0]
  86. expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
  87. assert expected_function_name == tool_call["function"]["name"]
  88. actual_arguments = tool_call["function"]["arguments"]
  89. assert isinstance(actual_arguments, str)
  90. if argument_key is not None:
  91. actual_arguments = json.loads(actual_arguments)
  92. assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
# Fast subset of the required-tool test: a few cheap templates run on every CI pass.
@pytest.mark.parametrize("template_name,tool,argument_key", [
    ("google-gemma-2-2b-it", TEST_TOOL, "success"),
    ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
    ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
    do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
# Full template matrix for the required-tool test; marked slow so it only runs
# in the extended suite.
@pytest.mark.slow
@pytest.mark.parametrize("template_name,tool,argument_key", [
    ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
    ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
    ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
    ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
    ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
    ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
    ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
    ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
    ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
    ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
    ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
    ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
    do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
# Same required-tool check as above, but against real quantized models pulled
# from Hugging Face instead of the tiny fixture model.
@pytest.mark.slow
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
    (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
    (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
    (TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    # TODO: fix these
    # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None):
    """Download `hf_repo`, optionally override its chat template, and assert that
    `tool_choice: "required"` yields exactly one call to `tool` whose arguments
    (when `argument_key` is set) are JSON containing that key.
    """
    n_predict = 512
    server.n_slots = 1
    server.jinja = True
    server.n_ctx = 8192
    server.n_predict = n_predict
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    if template_override:
        # Use a locally cached template instead of the one embedded in the GGUF.
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predict,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "tool_choice": "required",
        "tools": [tool],
        "parallel_tool_calls": False,
        # Greedy sampling keeps the generated call deterministic across runs.
        "temperature": 0.0,
        "top_k": 1,
        "top_p": 1.0,
    }, timeout=TIMEOUT_HTTP_REQUEST)
    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
    choice = res.body["choices"][0]
    tool_calls = choice["message"].get("tool_calls")
    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
    tool_call = tool_calls[0]
    # code_interpreter-type tools surface as a function named "python".
    expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
    assert expected_function_name == tool_call["function"]["name"]
    actual_arguments = tool_call["function"]["arguments"]
    # Arguments arrive as a JSON-encoded string, per the OpenAI API shape.
    assert isinstance(actual_arguments, str)
    if argument_key is not None:
        actual_arguments = json.loads(actual_arguments)
        assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
  187. def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
  188. global server
  189. server.jinja = True
  190. server.n_predict = n_predict
  191. server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
  192. server.start(timeout_seconds=TIMEOUT_SERVER_START)
  193. res = server.make_request("POST", "/chat/completions", data={
  194. "max_tokens": n_predict,
  195. "messages": [
  196. {"role": "system", "content": "You are a coding assistant."},
  197. {"role": "user", "content": "say hello world with python"},
  198. ],
  199. "tools": tools if tools else None,
  200. "tool_choice": tool_choice,
  201. "temperature": 0.0,
  202. "top_k": 1,
  203. "top_p": 1.0,
  204. }, timeout=TIMEOUT_HTTP_REQUEST)
  205. assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
  206. choice = res.body["choices"][0]
  207. assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
# Fast subset: no tools, unused tools, and tool_choice="none" must all yield
# a plain completion.
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
    ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
    ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
    ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
    do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
# Extended no-tool-call matrix across more templates; marked slow.
@pytest.mark.slow
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
    ("meetkai-functionary-medium-v3.2", 256, [], None),
    ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
    ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'),
    ("meetkai-functionary-medium-v3.1", 256, [], None),
    ("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None),
    ("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'),
    ("meta-llama-Llama-3.2-3B-Instruct", 256, [], None),
    ("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
    ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
    do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
# End-to-end check that real models call the weather tool with a sensible
# `location` argument for a natural-language question.
@pytest.mark.slow
@pytest.mark.parametrize("hf_repo,template_override", [
    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | None] | None):
    """Ask about the weather in Istanbul and assert the model emits exactly one
    get_current_weather call whose `location` matches Istanbul (optionally with
    a TR/Turkey/Türkiye suffix).
    """
    global server
    server.n_slots = 1
    server.jinja = True
    server.n_ctx = 8192
    server.n_predict = 512
    server.model_hf_repo = hf_repo
    server.model_hf_file = None
    if template_override:
        # Use a locally cached template instead of the one embedded in the GGUF.
        (template_hf_repo, template_variant) = template_override
        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 256,
        "messages": [
            {"role": "user", "content": "What is the weather in Istanbul?"},
        ],
        "tools": [WEATHER_TOOL],
    }, timeout=TIMEOUT_HTTP_REQUEST)
    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
    choice = res.body["choices"][0]
    tool_calls = choice["message"].get("tool_calls")
    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
    tool_call = tool_calls[0]
    assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
    actual_arguments = json.loads(tool_call["function"]["arguments"])
    assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
    location = actual_arguments["location"]
    assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
    assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
  274. @pytest.mark.slow
  275. @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
  276. (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
  277. (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
  278. (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
  279. ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
  280. (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
  281. ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
  282. (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
  283. (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
  284. (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
  285. (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
  286. # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
  287. ])
  288. def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None):
  289. global server
  290. server.n_slots = 1
  291. server.jinja = True
  292. server.n_ctx = 8192
  293. server.n_predict = 128
  294. server.model_hf_repo = hf_repo
  295. server.model_hf_file = None
  296. if template_override:
  297. (template_hf_repo, template_variant) = template_override
  298. server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
  299. assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
  300. server.start(timeout_seconds=TIMEOUT_SERVER_START)
  301. res = server.make_request("POST", "/chat/completions", data={
  302. "max_tokens": 256,
  303. "messages": [
  304. {"role": "system", "content": "You are a coding assistant."},
  305. {"role": "user", "content": "say hello world with python"},
  306. ],
  307. "tools": [PYTHON_TOOL],
  308. # Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n print("Hello, World!")\nhello_world()` which is correct but a pain to test.
  309. "temperature": 0.0,
  310. "top_k": 1,
  311. "top_p": 1.0,
  312. }, timeout=TIMEOUT_HTTP_REQUEST)
  313. assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
  314. choice = res.body["choices"][0]
  315. tool_calls = choice["message"].get("tool_calls")
  316. assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
  317. tool_call = tool_calls[0]
  318. assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
  319. actual_arguments = tool_call["function"]["arguments"]
  320. if expected_arguments_override is not None:
  321. assert actual_arguments == expected_arguments_override
  322. else:
  323. actual_arguments = json.loads(actual_arguments)
  324. assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
  325. code = actual_arguments["code"]
  326. assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
  327. assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'