# test_template.py
  1. #!/usr/bin/env python
  2. import pytest
  3. # ensure grandparent path is in sys.path
  4. from pathlib import Path
  5. import sys
  6. from unit.test_tool_call import TEST_TOOL
  7. path = Path(__file__).resolve().parents[1]
  8. sys.path.insert(0, str(path))
  9. import datetime
  10. from utils import *
  11. server: ServerProcess
  12. TIMEOUT_SERVER_START = 15*60
  13. @pytest.fixture(autouse=True)
  14. def create_server():
  15. global server
  16. server = ServerPreset.tinyllama2()
  17. server.model_alias = "tinyllama-2"
  18. server.server_port = 8081
  19. server.n_slots = 1
  20. @pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
  21. @pytest.mark.parametrize("template_name,reasoning_budget,expected_end", [
  22. ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", None, "<think>\n"),
  23. ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", -1, "<think>\n"),
  24. ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", 0, "<think>\n</think>"),
  25. ("Qwen-Qwen3-0.6B", -1, "<|im_start|>assistant\n"),
  26. ("Qwen-Qwen3-0.6B", 0, "<|im_start|>assistant\n<think>\n\n</think>\n\n"),
  27. ("Qwen-QwQ-32B", -1, "<|im_start|>assistant\n<think>\n"),
  28. ("Qwen-QwQ-32B", 0, "<|im_start|>assistant\n<think>\n</think>"),
  29. ("CohereForAI-c4ai-command-r7b-12-2024-tool_use", -1, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"),
  30. ("CohereForAI-c4ai-command-r7b-12-2024-tool_use", 0, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|><|END_THINKING|>"),
  31. ])
  32. def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expected_end: str, tools: list[dict]):
  33. global server
  34. server.jinja = True
  35. server.reasoning_budget = reasoning_budget
  36. server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
  37. server.start(timeout_seconds=TIMEOUT_SERVER_START)
  38. res = server.make_request("POST", "/apply-template", data={
  39. "messages": [
  40. {"role": "user", "content": "What is today?"},
  41. ],
  42. "tools": tools,
  43. })
  44. assert res.status_code == 200
  45. prompt = res.body["prompt"]
  46. assert prompt.endswith(expected_end), f"Expected prompt to end with '{expected_end}', got '{prompt}'"
  47. @pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
  48. @pytest.mark.parametrize("template_name,format", [
  49. ("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
  50. ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
  51. ])
  52. def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
  53. global server
  54. server.jinja = True
  55. server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
  56. server.start(timeout_seconds=TIMEOUT_SERVER_START)
  57. res = server.make_request("POST", "/apply-template", data={
  58. "messages": [
  59. {"role": "user", "content": "What is today?"},
  60. ],
  61. "tools": tools,
  62. })
  63. assert res.status_code == 200
  64. prompt = res.body["prompt"]
  65. today_str = datetime.date.today().strftime(format)
  66. assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
  67. @pytest.mark.parametrize("add_generation_prompt", [False, True])
  68. @pytest.mark.parametrize("template_name,expected_generation_prompt", [
  69. ("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
  70. ])
  71. def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
  72. global server
  73. server.jinja = True
  74. server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
  75. server.start(timeout_seconds=TIMEOUT_SERVER_START)
  76. res = server.make_request("POST", "/apply-template", data={
  77. "messages": [
  78. {"role": "user", "content": "What is today?"},
  79. ],
  80. "add_generation_prompt": add_generation_prompt,
  81. })
  82. assert res.status_code == 200
  83. prompt = res.body["prompt"]
  84. if add_generation_prompt:
  85. assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
  86. else:
  87. assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"