gguf-new-metadata.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. #!/usr/bin/env python3
  2. import logging
  3. import argparse
  4. import os
  5. import sys
  6. import json
  7. from pathlib import Path
  8. import numpy as np
  9. from typing import Any, Mapping, Sequence
  # Make the in-tree gguf package importable ahead of any installed copy,
  # unless the NO_LOCAL_GGUF environment variable is present.
  if "NO_LOCAL_GGUF" not in os.environ:
      if (Path(__file__).parent.parent.parent / 'gguf-py').exists():
          sys.path.insert(0, str(Path(__file__).parent.parent))

  import gguf

  logger = logging.getLogger("gguf-new-metadata")
  def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
      """Return the byte order of the file opened by *reader*.

      When ``reader.byte_order`` equals ``"S"`` the result is the opposite of
      the host byte order; for any other value it is the host byte order.
      """
      # sys.byteorder reports the host order directly and does not depend on
      # the NumPy scalar method ``newbyteorder``, which NumPy 2.0 removed.
      if sys.byteorder == "little":
          host_endian = gguf.GGUFEndian.LITTLE
          swapped_endian = gguf.GGUFEndian.BIG
      else:
          # Sorry PDP or other weird systems that don't use BE or LE.
          host_endian = gguf.GGUFEndian.BIG
          swapped_endian = gguf.GGUFEndian.LITTLE

      return swapped_endian if reader.byte_order == "S" else host_endian
  def decode_field(field: gguf.ReaderField) -> Any:
      """Convert a reader field into a plain Python value.

      Returns ``None`` when *field* is falsy or has no type information.
      String fields are decoded as UTF-8 from the last part; arrays of
      strings become a list of decoded strings; other arrays are flattened
      into a single list; any other type yields the first element of the
      last part.
      """
      if not field or not field.types:
          return None

      value_type = field.types[0]

      if value_type == gguf.GGUFValueType.STRING:
          return bytes(field.parts[-1]).decode('utf8')

      if value_type != gguf.GGUFValueType.ARRAY:
          return field.parts[-1][0]

      # Array: the last entry in ``types`` is the element type.
      if field.types[-1] == gguf.GGUFValueType.STRING:
          return [bytes(field.parts[idx]).decode('utf8') for idx in field.data]

      flattened = []
      for idx in field.data:
          flattened.extend(field.parts[idx].tolist())
      return flattened
  def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
      """Look up *key* via ``reader.get_field`` and return its decoded value."""
      return decode_field(reader.get_field(key))
  def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None:
      """Write the reader's contents to *writer* with metadata changes applied.

      Args:
          reader: Source of the existing fields and tensors.
          writer: Destination; it is closed before this function returns,
              including when an error is raised part-way through.
          new_metadata: Key/value overrides. Keys matching an existing field
              replace its value (keeping the field's original type); keys not
              present in the source are appended as strings. The chat template
              key is handed to ``writer.add_chat_template``. The mapping
              itself is not modified.
          remove_metadata: Field names to leave out of the output.
      """
      # Entries are consumed as they are written, so work on a private copy
      # instead of deleting from the caller's (read-only typed) mapping.
      pending = dict(new_metadata)
      removed = set(remove_metadata)
      chat_template_key = gguf.Keys.Tokenizer.CHAT_TEMPLATE
      # Chat template entries are never consumed inside the field loop, so this
      # membership test is loop-invariant.
      replace_chat_template = chat_template_key in pending

      try:
          for field in reader.fields.values():
              # Suppress virtual fields and fields written by GGUFWriter
              if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
                  logger.debug(f'Suppressing {field.name}')
                  continue

              # Skip old chat templates if we have new ones
              if replace_chat_template and field.name.startswith(chat_template_key):
                  logger.debug(f'Skipping {field.name}')
                  continue

              if field.name in removed:
                  logger.debug(f'Removing {field.name}')
                  continue

              old_val = decode_field(field)

              if field.name in pending:
                  val = pending.pop(field.name)
                  logger.debug(f'Modifying {field.name}: "{old_val}" -> "{val}"')
              else:
                  val = old_val
                  if val is not None:
                      logger.debug(f'Copying {field.name}')

              if val is not None:
                  writer.add_key(field.name)
                  writer.add_val(val, field.types[0])

          if chat_template_key in pending:
              logger.debug('Adding chat template(s)')
              writer.add_chat_template(pending.pop(chat_template_key))

          # TODO: Support other types than string?
          for key, val in pending.items():
              logger.debug(f'Adding {key}: {val}')
              writer.add_key(key)
              writer.add_val(val, gguf.GGUFValueType.STRING)

          for tensor in reader.tensors:
              # Dimensions are written in reverse order, so flip them first
              shape = np.flipud(tensor.shape)
              writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)

          writer.write_header_to_file()
          writer.write_kv_data_to_file()
          writer.write_ti_data_to_file()

          for tensor in reader.tensors:
              writer.write_tensor_data(tensor.data)
      finally:
          # Release the output handle even if copying fails part-way through.
          writer.close()
  def _confirm_or_exit() -> None:
      """Ask for a literal ``YES`` on stdin; exit with status 0 otherwise."""
      logger.warning('* Enter exactly YES if you are positive you want to proceed:')
      response = input('YES, I am sure> ')
      if response != 'YES':
          logger.info("You didn't enter YES. Okay then, see ya!")
          sys.exit(0)


  def main() -> None:
      """Command-line entry point: copy a GGUF file, applying metadata edits.

      Parses the arguments, collects the requested metadata overrides and
      removals, asks for confirmation before removing metadata or overwriting
      an existing output file (unless ``--force`` is given), then copies the
      input file to the output file with the changes applied.
      """
      parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata")
      parser.add_argument("input", type=Path, help="GGUF format model input filename")
      parser.add_argument("output", type=Path, help="GGUF format model output filename")
      parser.add_argument("--general-name", type=str, help="The models general.name")
      parser.add_argument("--general-description", type=str, help="The models general.description")
      parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)")
      parser.add_argument("--chat-template-config", type=Path, help="Config file (tokenizer_config.json) containing chat template(s)")
      parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model")
      parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
      parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")

      # Both positionals are required; with fewer arguments, show the help text.
      args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"])

      logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

      new_metadata = {}
      remove_metadata = args.remove_metadata or []

      if args.general_name:
          new_metadata[gguf.Keys.General.NAME] = args.general_name

      if args.general_description:
          new_metadata[gguf.Keys.General.DESCRIPTION] = args.general_description

      if args.chat_template:
          # A value starting with '[' is parsed as JSON (a list of templates).
          new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template

      if args.chat_template_config:
          # Read as UTF-8 explicitly so the result does not depend on the
          # platform's default text encoding.
          with open(args.chat_template_config, 'r', encoding='utf-8') as fp:
              config = json.load(fp)
              template = config.get('chat_template')
              if template:
                  new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = template

      if remove_metadata:
          logger.warning('*** Warning *** Warning *** Warning **')
          logger.warning('* Most metadata is required for a fully functional GGUF file,')
          logger.warning('* removing crucial metadata may result in a corrupt output file!')

          if not args.force:
              _confirm_or_exit()

      logger.info(f'* Loading: {args.input}')
      reader = gguf.GGUFReader(args.input, 'r')

      arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
      endianess = get_byteorder(reader)

      if os.path.isfile(args.output) and not args.force:
          logger.warning('*** Warning *** Warning *** Warning **')
          logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
          _confirm_or_exit()

      logger.info(f'* Writing: {args.output}')
      writer = gguf.GGUFWriter(args.output, arch=arch, endianess=endianess)

      alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT)
      if alignment is not None:
          logger.debug(f'Setting custom alignment: {alignment}')
          writer.data_alignment = alignment

      copy_with_new_metadata(reader, writer, new_metadata, remove_metadata)
  # Run the command-line tool only when executed as a script, not on import.
  if __name__ == "__main__":
      main()