| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- package chat
- import (
- "bytes"
- "encoding/json"
- "fmt"
- "strings"
- )
// StripThinking splits model output into its <think>…</think> reasoning and
// the remaining visible content.
//
// Only the first "<think>" is considered, and the matching "</think>" is
// searched for after it, so a stray "</think>" earlier in s stays in the
// content. When both tags are found, reasoning is the text between them with
// surrounding newlines trimmed, and content is everything outside the tags
// with leading newlines trimmed. When "<think>" has no closing tag (e.g. a
// truncated generation), everything after it is returned as reasoning and
// content is the preceding text with surrounding whitespace trimmed. When s
// contains no "<think>", reasoning is empty and content is s unchanged.
func StripThinking(s string) (reasoning string, content string) {
	const openTag, closeTag = "<think>", "</think>"

	start := strings.Index(s, openTag)
	if start < 0 {
		return "", s
	}
	before := s[:start]
	inner := s[start+len(openTag):]

	// Look for the closing tag only after the opening one. Searching all of s
	// would find a "</think>" that precedes "<think>", make the block look
	// malformed, and leak the reasoning into the visible content.
	end := strings.Index(inner, closeTag)
	if end < 0 {
		// Thinking started but never ended (truncated generation): drop it
		// from the visible output.
		return strings.Trim(inner, "\n"), strings.TrimSpace(before)
	}

	reasoning = strings.Trim(inner[:end], "\n")
	content = strings.TrimLeft(before+inner[end+len(closeTag):], "\n")
	return reasoning, content
}
- func ExtractToolCalls(s string) (content string, calls []ToolCall, err error) {
- var out strings.Builder
- rest := s
- for {
- start := strings.Index(rest, "<tool_call>")
- if start < 0 {
- out.WriteString(rest)
- break
- }
- out.WriteString(rest[:start])
- rest = rest[start+len("<tool_call>"):]
- end := strings.Index(rest, "</tool_call>")
- if end < 0 {
- return "", nil, fmt.Errorf("unterminated <tool_call>")
- }
- block := strings.TrimSpace(rest[:end])
- rest = rest[end+len("</tool_call>"):]
- // Some models include trailing commas/newlines; keep robust.
- block = strings.Trim(block, "\n\r\t ")
- var raw struct {
- Name string `json:"name"`
- Arguments json.RawMessage `json:"arguments"`
- }
- dec := json.NewDecoder(bytes.NewReader([]byte(block)))
- dec.DisallowUnknownFields()
- if err := dec.Decode(&raw); err != nil {
- // Fallback: allow unknown fields.
- if err2 := json.Unmarshal([]byte(block), &raw); err2 != nil {
- return "", nil, fmt.Errorf("parse tool_call json: %w", err)
- }
- }
- if raw.Name == "" {
- return "", nil, fmt.Errorf("tool_call missing name")
- }
- calls = append(calls, ToolCall{Name: raw.Name, Arguments: raw.Arguments})
- }
- return out.String(), calls, nil
- }
|