test-chat-template.cpp

#include <iostream>
#include <string>
#include <vector>
#include <sstream>

#undef NDEBUG // keep assert() active even in release builds
#include <cassert>

#include "llama.h"
int main(void) {
    // a fixed multi-turn conversation used as input for every template
    llama_chat_message conversation[] = {
        {"system",    "You are a helpful assistant"},
        {"user",      "Hello"},
        {"assistant", "Hi there"},
        {"user",      "Who are you"},
        {"assistant", " I am an assistant "},
        {"user",      "Another question"},
    };
    size_t message_count = 6; // number of entries in conversation[]
    std::vector<std::string> templates = {
        // teknium/OpenHermes-2.5-Mistral-7B
        "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
        // mistralai/Mistral-7B-Instruct-v0.2
        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
        // TheBloke/FusionNet_34Bx2_MoE-AWQ
        "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
        // bofenghuang/vigogne-2-70b-chat
        "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
    };
    // each formatted output must contain the corresponding substring below
    // (same order as the templates above)
    std::vector<std::string> expected_substr = {
        // teknium/OpenHermes-2.5-Mistral-7B
        "<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
        // mistralai/Mistral-7B-Instruct-v0.2
        "[/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
        // TheBloke/FusionNet_34Bx2_MoE-AWQ
        "</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
        // bofenghuang/vigogne-2-70b-chat
        "[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
    };
    std::vector<char> formatted_chat(1024);
    int32_t res;

    // an unrecognized template must be rejected with a negative return value
    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
    assert(res < 0);
    for (size_t i = 0; i < templates.size(); i++) {
        std::string custom_template = templates[i];
        std::string substr = expected_substr[i];
        formatted_chat.resize(1024);
        res = llama_chat_apply_template(
            nullptr,
            custom_template.c_str(),
            conversation,
            message_count,
            true,
            formatted_chat.data(),
            formatted_chat.size()
        );
        assert(res >= 0); // guard: resize(res) below would misbehave on a negative result
        formatted_chat.resize(res); // res is the length of the formatted output
        std::string output(formatted_chat.data(), formatted_chat.size());
        std::cout << output << "\n-------------------------\n";
        // expect "formatted_chat" to contain the pre-defined substring
        assert(output.find(substr) != std::string::npos);
    }
    return 0;
}
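
// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the original test. llama.h documents that
// llama_chat_apply_template() returns the total number of bytes of the
// formatted prompt, which may exceed the buffer it was given, so a caller can
// use a grow-and-retry pattern instead of the fixed 1024-byte buffer above.
// The helper name apply_chat_template_str is hypothetical.
std::string apply_chat_template_str(
        const char * tmpl,
        const llama_chat_message * msgs,
        size_t n_msg,
        bool add_assistant_prompt) {
    std::vector<char> buf(256);
    int32_t n = llama_chat_apply_template(nullptr, tmpl, msgs, n_msg,
            add_assistant_prompt, buf.data(), buf.size());
    if (n < 0) {
        return ""; // the template was not recognized
    }
    if ((size_t) n > buf.size()) {
        buf.resize(n); // first buffer was too small: grow to the reported size
        n = llama_chat_apply_template(nullptr, tmpl, msgs, n_msg,
                add_assistant_prompt, buf.data(), buf.size());
    }
    return std::string(buf.data(), n);
}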