test_template.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #!/usr/bin/env python
  2. import pytest
  3. # ensure grandparent path is in sys.path
  4. from pathlib import Path
  5. import sys
  6. from unit.test_tool_call import TEST_TOOL
  7. path = Path(__file__).resolve().parents[1]
  8. sys.path.insert(0, str(path))
  9. import datetime
  10. from utils import *
  11. server: ServerProcess
  12. TIMEOUT_SERVER_START = 15*60
  13. @pytest.fixture(autouse=True)
  14. def create_server():
  15. global server
  16. server = ServerPreset.tinyllama2()
  17. server.model_alias = "tinyllama-2"
  18. server.server_port = 8081
  19. server.n_slots = 1
  20. @pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
  21. @pytest.mark.parametrize("template_name,format", [
  22. ("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
  23. ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
  24. ])
  25. def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
  26. global server
  27. server.jinja = True
  28. server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
  29. server.start(timeout_seconds=TIMEOUT_SERVER_START)
  30. res = server.make_request("POST", "/apply-template", data={
  31. "messages": [
  32. {"role": "user", "content": "What is today?"},
  33. ],
  34. "tools": tools,
  35. })
  36. assert res.status_code == 200
  37. prompt = res.body["prompt"]
  38. today_str = datetime.date.today().strftime(format)
  39. assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
  40. @pytest.mark.parametrize("add_generation_prompt", [False, True])
  41. @pytest.mark.parametrize("template_name,expected_generation_prompt", [
  42. ("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
  43. ])
  44. def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
  45. global server
  46. server.jinja = True
  47. server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
  48. server.start(timeout_seconds=TIMEOUT_SERVER_START)
  49. res = server.make_request("POST", "/apply-template", data={
  50. "messages": [
  51. {"role": "user", "content": "What is today?"},
  52. ],
  53. "add_generation_prompt": add_generation_prompt,
  54. })
  55. assert res.status_code == 200
  56. prompt = res.body["prompt"]
  57. if add_generation_prompt:
  58. assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
  59. else:
  60. assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"