render_qwen3.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. package chat
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "strings"
  6. )
  7. // Qwen3Renderer implements the Qwen3 chat template behavior commonly shipped
  8. // in tokenizer_config.json (including tool-call wrappers).
  9. // It is a deterministic renderer (not a general Jinja interpreter).
  10. //
  11. //nolint:revive // exported API
  12. type Qwen3Renderer struct{}
  13. func (r *Qwen3Renderer) Render(messages []Message, opts Options) (string, error) {
  14. var sb strings.Builder
  15. // Tools header block (only when tools are provided)
  16. if len(opts.Tools) > 0 {
  17. sb.WriteString("<|im_start|>system\n")
  18. if len(messages) > 0 && messages[0].Role == "system" {
  19. sb.WriteString(messages[0].Content)
  20. sb.WriteString("\n\n")
  21. }
  22. sb.WriteString("# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>")
  23. for _, tool := range opts.Tools {
  24. b, err := json.Marshal(tool)
  25. if err != nil {
  26. return "", fmt.Errorf("marshal tool: %w", err)
  27. }
  28. sb.WriteString("\n")
  29. sb.Write(b)
  30. }
  31. sb.WriteString("\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n")
  32. } else {
  33. // No-tools path: include system message only if first message is system.
  34. if len(messages) > 0 && messages[0].Role == "system" {
  35. sb.WriteString("<|im_start|>system\n")
  36. sb.WriteString(messages[0].Content)
  37. sb.WriteString("<|im_end|>\n")
  38. }
  39. }
  40. // Find last user query index for thinking rendering.
  41. lastQueryIdx := len(messages) - 1
  42. multiStepTool := true
  43. for i := len(messages) - 1; i >= 0; i-- {
  44. m := messages[i]
  45. if multiStepTool && m.Role == "user" && !strings.HasPrefix(m.Content, "<tool_response>") {
  46. multiStepTool = false
  47. lastQueryIdx = i
  48. break
  49. }
  50. }
  51. // Render message stream.
  52. inToolGroup := false
  53. for i := 0; i < len(messages); i++ {
  54. m := messages[i]
  55. switch m.Role {
  56. case "tool":
  57. // Tool responses are wrapped as a fake user message with <tool_response> blocks.
  58. if !inToolGroup {
  59. sb.WriteString("<|im_start|>user")
  60. inToolGroup = true
  61. }
  62. sb.WriteString("\n<tool_response>\n")
  63. sb.WriteString(m.Content)
  64. sb.WriteString("\n</tool_response>")
  65. // Close tool group if next isn't tool.
  66. if i == len(messages)-1 || messages[i+1].Role != "tool" {
  67. sb.WriteString("<|im_end|>\n")
  68. inToolGroup = false
  69. }
  70. continue
  71. default:
  72. if inToolGroup {
  73. // Safety: close an unterminated tool group.
  74. sb.WriteString("<|im_end|>\n")
  75. inToolGroup = false
  76. }
  77. }
  78. content := m.Content
  79. if m.Role == "user" || (m.Role == "system" && i != 0) {
  80. sb.WriteString("<|im_start|>")
  81. sb.WriteString(m.Role)
  82. sb.WriteString("\n")
  83. sb.WriteString(content)
  84. sb.WriteString("<|im_end|>\n")
  85. continue
  86. }
  87. if m.Role == "assistant" {
  88. sb.WriteString("<|im_start|>assistant\n")
  89. reasoning := m.ReasoningContent
  90. if opts.EnableThinking {
  91. // Render thinking only after the last user query index.
  92. if i > lastQueryIdx {
  93. sb.WriteString("<think>\n")
  94. sb.WriteString(strings.Trim(reasoning, "\n"))
  95. sb.WriteString("\n</think>\n\n")
  96. }
  97. }
  98. sb.WriteString(strings.TrimLeft(content, "\n"))
  99. // Render tool calls, if any.
  100. if len(m.ToolCalls) > 0 {
  101. for j, tc := range m.ToolCalls {
  102. if (j == 0 && content != "") || j > 0 {
  103. sb.WriteString("\n")
  104. }
  105. sb.WriteString("<tool_call>\n{\"name\": \"")
  106. sb.WriteString(tc.Name)
  107. sb.WriteString("\", \"arguments\": ")
  108. if len(tc.Arguments) == 0 {
  109. sb.WriteString("{}")
  110. } else {
  111. sb.Write(tc.Arguments)
  112. }
  113. sb.WriteString("}\n</tool_call>")
  114. }
  115. }
  116. sb.WriteString("<|im_end|>\n")
  117. continue
  118. }
  119. }
  120. if opts.AddGenerationPrompt {
  121. sb.WriteString("<|im_start|>assistant\n")
  122. }
  123. return sb.String(), nil
  124. }