|
|
@@ -47,3 +47,28 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
|
|
|
|
|
|
today_str = datetime.date.today().strftime(format)
|
|
|
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize("add_generation_prompt", [False, True])
|
|
|
+@pytest.mark.parametrize("template_name,expected_generation_prompt", [
|
|
|
+ ("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
|
|
|
+])
|
|
|
+def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
|
|
|
+ global server
|
|
|
+ server.jinja = True
|
|
|
+ server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
|
|
+ server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
|
|
+
|
|
|
+ res = server.make_request("POST", "/apply-template", data={
|
|
|
+ "messages": [
|
|
|
+ {"role": "user", "content": "What is today?"},
|
|
|
+ ],
|
|
|
+ "add_generation_prompt": add_generation_prompt,
|
|
|
+ })
|
|
|
+ assert res.status_code == 200
|
|
|
+ prompt = res.body["prompt"]
|
|
|
+
|
|
|
+ if add_generation_prompt:
|
|
|
+ assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
|
|
|
+ else:
|
|
|
+ assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"
|