convert-persimmon-to-gguf.py

#!/usr/bin/env python3
from __future__ import annotations

import logging
import argparse
import os
import sys
from pathlib import Path
from pprint import pprint

import torch
from sentencepiece import SentencePieceProcessor

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf

logger = logging.getLogger("persimmon-to-gguf")
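

# Collect every tensor in a (possibly nested) state dict under a dotted key
# path, e.g. {'a': {'b': t}} becomes tensors['a.b'] = t.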
def _flatten_dict(dct, tensors, prefix=None):
    assert isinstance(dct, dict)
    for key in dct.keys():
        new_prefix = prefix + '.' + key if prefix is not None else key
        if isinstance(dct[key], torch.Tensor):
            tensors[new_prefix] = dct[key]
        elif isinstance(dct[key], dict):
            _flatten_dict(dct[key], tensors, new_prefix)
        else:
            raise ValueError(type(dct[key]))
    return None
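

# Load the SentencePiece vocabulary shipped with the model and return parallel
# lists of token bytes, scores, and GGUF token types
# (1 = normal, 2 = unknown, 3 = control, 5 = unused, 6 = byte).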
def _get_sentencepiece_tokenizer_info(dir_model: Path):
    tokenizer_path = dir_model / 'adept_vocab.model'
    logger.info('getting sentencepiece tokenizer from %s', tokenizer_path)
    tokenizer = SentencePieceProcessor(str(tokenizer_path))
    logger.info('adding tokens')
    tokens: list[bytes] = []
    scores: list[float] = []
    toktypes: list[int] = []

    for i in range(tokenizer.vocab_size()):
        text: bytes
        score: float

        piece = tokenizer.id_to_piece(i)
        text = piece.encode("utf-8")
        score = tokenizer.get_score(i)

        toktype = 1
        if tokenizer.is_unknown(i):
            toktype = 2
        if tokenizer.is_control(i):
            toktype = 3
        if tokenizer.is_unused(i):
            toktype = 5
        if tokenizer.is_byte(i):
            toktype = 6

        tokens.append(text)
        scores.append(score)
        toktypes.append(toktype)
    return tokens, scores, toktypes
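

# Example invocation (paths are illustrative, not the actual release layout):
#   python convert-persimmon-to-gguf.py --model-dir 8b_chat_model_release \
#       --ckpt-path <path to checkpoint .pt> \
#       --adept-inference-dir <path to adept-inference checkout> \
#       --outfile persimmon-8b-chat.gguf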
def main():
    parser = argparse.ArgumentParser(description="Convert a Persimmon model from Adept (e.g. Persimmon 8b chat) to a GGML compatible file")
    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
    parser.add_argument("--ckpt-path", type=Path, help="path to persimmon checkpoint .pt file")
    parser.add_argument("--model-dir", type=Path, help="directory containing model e.g. 8b_chat_model_release")
    parser.add_argument("--adept-inference-dir", type=str, help="path to adept-inference code directory")
    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    sys.path.append(str(args.adept_inference_dir))
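
    # The checkpoint pickle references classes from the adept-inference
    # codebase (hence the sys.path.append above); 'args' holds the training
    # hyperparameters and 'model' a nested state dict of tensors.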
    persimmon_model = torch.load(args.ckpt_path)
    hparams = persimmon_model['args']
    pprint(hparams)
    tensors: dict[str, torch.Tensor] = {}
    _flatten_dict(persimmon_model['model'], tensors, None)
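
    # Architecture-level metadata, read straight from the checkpoint's
    # hyperparameters.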
    arch = gguf.MODEL_ARCH.PERSIMMON
    gguf_writer = gguf.GGUFWriter(args.outfile, gguf.MODEL_ARCH_NAMES[arch])

    block_count = hparams.num_layers
    head_count = hparams.num_attention_heads
    head_count_kv = head_count
    ctx_length = hparams.seq_length
    hidden_size = hparams.hidden_size

    gguf_writer.add_name('persimmon-8b-chat')
    gguf_writer.add_context_length(ctx_length)
    gguf_writer.add_embedding_length(hidden_size)
    gguf_writer.add_block_count(block_count)
    gguf_writer.add_feed_forward_length(hparams.ffn_hidden_size)
    # ref: https://github.com/ggerganov/llama.cpp/pull/4889/commits/eea19039fc52ea2dbd1aab45b59ab4e3e29a3443
    gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2)
    gguf_writer.add_head_count(head_count)
    gguf_writer.add_head_count_kv(head_count_kv)
    gguf_writer.add_rope_freq_base(hparams.rotary_emb_base)
    gguf_writer.add_layer_norm_eps(hparams.layernorm_epsilon)
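
    # The vocabulary is registered as a llama-style SentencePiece tokenizer;
    # note that BOS and EOS share the same id (71013) in this model.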
    tokens, scores, toktypes = _get_sentencepiece_tokenizer_info(args.model_dir)
    gguf_writer.add_tokenizer_model('llama')
    gguf_writer.add_tokenizer_pre('default')
    gguf_writer.add_token_list(tokens)
    gguf_writer.add_token_scores(scores)
    gguf_writer.add_token_types(toktypes)
    gguf_writer.add_bos_token_id(71013)
    gguf_writer.add_eos_token_id(71013)
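
    # Map checkpoint tensor names onto the standard GGUF names for this
    # architecture; the precomputed rotary inverse frequencies are skipped,
    # since they can be recomputed from the RoPE parameters at load time.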
    tensor_map = gguf.get_tensor_name_map(arch, block_count)
    logger.info(tensor_map)
    for name in tensors.keys():
        data_torch = tensors[name]
        if name.endswith(".self_attention.rotary_emb.inv_freq"):
            continue
        old_dtype = data_torch.dtype
        # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
        data = data_torch.to(torch.float32).squeeze().numpy()
        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
        if new_name is None:
            raise ValueError(f"Cannot map tensor '{name}'")
        n_dims = len(data.shape)
        logger.debug(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
        gguf_writer.add_tensor(new_name, data)
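
    # A GGUF file is written in three passes: header, key/value metadata,
    # then tensor data.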
    logger.info("gguf: write header")
    gguf_writer.write_header_to_file()
    logger.info("gguf: write metadata")
    gguf_writer.write_kv_data_to_file()
    logger.info("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

    gguf_writer.close()

    logger.info(f"gguf: model successfully exported to '{args.outfile}'")


if __name__ == '__main__':
    main()