1
0

parse_tool_calls.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. package chat
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "strings"
  7. )
  8. func StripThinking(s string) (reasoning string, content string) {
  9. start := strings.Index(s, "<think>")
  10. end := strings.Index(s, "</think>")
  11. if start >= 0 && end > start {
  12. reasoning = s[start+len("<think>"):end]
  13. content = s[:start] + s[end+len("</think>"):]
  14. return strings.Trim(reasoning, "\n"), strings.TrimLeft(content, "\n")
  15. }
  16. // If thinking starts but doesn't end (truncated generation), drop it from visible output.
  17. if start >= 0 && end < 0 {
  18. reasoning = s[start+len("<think>"):]
  19. content = s[:start]
  20. return strings.Trim(reasoning, "\n"), strings.TrimSpace(content)
  21. }
  22. return "", s
  23. }
  24. func ExtractToolCalls(s string) (content string, calls []ToolCall, err error) {
  25. var out strings.Builder
  26. rest := s
  27. for {
  28. start := strings.Index(rest, "<tool_call>")
  29. if start < 0 {
  30. out.WriteString(rest)
  31. break
  32. }
  33. out.WriteString(rest[:start])
  34. rest = rest[start+len("<tool_call>"):]
  35. end := strings.Index(rest, "</tool_call>")
  36. if end < 0 {
  37. return "", nil, fmt.Errorf("unterminated <tool_call>")
  38. }
  39. block := strings.TrimSpace(rest[:end])
  40. rest = rest[end+len("</tool_call>"):]
  41. // Some models include trailing commas/newlines; keep robust.
  42. block = strings.Trim(block, "\n\r\t ")
  43. var raw struct {
  44. Name string `json:"name"`
  45. Arguments json.RawMessage `json:"arguments"`
  46. }
  47. dec := json.NewDecoder(bytes.NewReader([]byte(block)))
  48. dec.DisallowUnknownFields()
  49. if err := dec.Decode(&raw); err != nil {
  50. // Fallback: allow unknown fields.
  51. if err2 := json.Unmarshal([]byte(block), &raw); err2 != nil {
  52. return "", nil, fmt.Errorf("parse tool_call json: %w", err)
  53. }
  54. }
  55. if raw.Name == "" {
  56. return "", nil, fmt.Errorf("tool_call missing name")
  57. }
  58. calls = append(calls, ToolCall{Name: raw.Name, Arguments: raw.Arguments})
  59. }
  60. return out.String(), calls, nil
  61. }