get_chat_template.py

#!/usr/bin/env python
'''
  Fetches the Jinja chat template of a HuggingFace model.
  If a model has multiple chat templates, you can specify the variant name.

  Syntax:
    ./scripts/get_chat_template.py model_id [variant]

  Examples:
    ./scripts/get_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
    ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
    ./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct
'''

import json
import re
import sys


def get_chat_template(model_id, variant=None):
    try:
        # Use huggingface_hub library if available.
        # Allows access to gated models if the user has access and ran `huggingface-cli login`.
        from huggingface_hub import hf_hub_download
        with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
            config_str = f.read()
    except ImportError:
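        # Fall back to a plain HTTPS request against the Hub when huggingface_hub is not installed.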
        import requests
        assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}"
        response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
        if response.status_code == 401:
            raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
        response.raise_for_status()
        config_str = response.text

    try:
        config = json.loads(config_str)
    except json.JSONDecodeError:
        # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
        # (Remove extra '}' near the end of the file)
        config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))

    chat_template = config['chat_template']
    if isinstance(chat_template, str):
        return chat_template
    else:
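        # Some models ship a list of named templates (e.g. "default", "tool_use") instead of a single string.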
        variants = {
            ct['name']: ct['template']
            for ct in chat_template
        }

        def format_variants():
            return ', '.join(f'"{v}"' for v in variants.keys())

        if variant is None:
            if 'default' not in variants:
                raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
            variant = 'default'
            sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n')
        elif variant not in variants:
            raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")

        return variants[variant]
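

# CLI entry point: args[0] is the model ID, optional args[1] is the chat template variant name.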
def main(args):
    if len(args) < 1:
        raise ValueError("Please provide a model ID and an optional variant name")
    model_id = args[0]
    variant = None if len(args) < 2 else args[1]

    template = get_chat_template(model_id, variant)

    sys.stdout.write(template)


if __name__ == '__main__':
    main(sys.argv[1:])